Skip to content

Commit

Permalink
compute mirror child types of a union
Browse files Browse the repository at this point in the history
  • Loading branch information
bishabosha committed May 5, 2022
1 parent 16ed844 commit 65a0d8c
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 12 deletions.
40 changes: 28 additions & 12 deletions compiler/src/dotty/tools/dotc/typer/Synthesizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import transform.SyntheticMembers._
import util.Property
import annotation.{tailrec, constructorOnly}

import scala.collection.mutable

/** Synthesize terms for special classes */
class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
import ast.tpd._
Expand Down Expand Up @@ -337,7 +339,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
if acceptable(mirroredType) && cls.isGenericSum(if useCompanion then cls.linkedClass else ctx.owner) then
val elemLabels = cls.children.map(c => ConstantType(Constant(c.name.toString)))

def solve(sym: Symbol): Type = sym match
def solve(target: Type)(sym: Symbol): Type = sym match
case childClass: ClassSymbol =>
assert(childClass.isOneOf(Case | Sealed))
if childClass.is(Module) then
Expand All @@ -348,36 +350,50 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
// Compute the the full child type by solving the subtype constraint
// `C[X1, ..., Xn] <: P`, where
//
// - P is the current `mirroredType`
// - P is the current `targetPart`
// - C is the child class, with type parameters X1, ..., Xn
//
// Contravariant type parameters are minimized, all other type parameters are maximized.
def instantiate(using Context) =
val poly = constrained(info, untpd.EmptyTree)._1
def instantiate(targetPart: Type)(using Context) =
val poly = constrained(info)
val resType = poly.finalResultType
val target = mirroredType match
case tp: HKTypeLambda => tp.resultType
case tp => tp
resType <:< target
resType <:< targetPart // record constraints
val tparams = poly.paramRefs
val variances = childClass.typeParams.map(_.paramVarianceSign)
val instanceTypes = tparams.lazyZip(variances).map((tparam, variance) =>
TypeComparer.instanceType(tparam, fromBelow = variance < 0))
resType.substParams(poly, instanceTypes)
instantiate(using ctx.fresh.setExploreTyperState().setOwner(childClass))

def instantiateAll(using Context): Type =

// instantiate for each part of a union type, compute lub of the results
def loop(explore: List[Type], acc: mutable.ListBuffer[Type]): Type = explore match
case OrType(tp1, tp2) :: rest => loop(tp1 :: tp2 :: rest, acc )
case tp :: rest => loop(rest , acc += instantiate(tp))
case _ => TypeComparer.lub(acc.toList)

def instantiateLub(tp1: Type, tp2: Type): Type =
loop(tp1 :: tp2 :: Nil, new mutable.ListBuffer[Type])

target match
case OrType(tp1, tp2) => instantiateLub(tp1, tp2)
case _ => instantiate(target)

instantiateAll(using ctx.fresh.setExploreTyperState().setOwner(childClass))
case _ =>
childClass.typeRef
case child => child.termRef
end solve

val (monoType, elemsType) = mirroredType match
case mirroredType: HKTypeLambda =>
val target = mirroredType.resultType
val elems = mirroredType.derivedLambdaType(
resType = TypeOps.nestedPairs(cls.children.map(solve))
resType = TypeOps.nestedPairs(cls.children.map(solve(target)))
)
(mkMirroredMonoType(mirroredType), elems)
case _ =>
val elems = TypeOps.nestedPairs(cls.children.map(solve))
case target =>
val elems = TypeOps.nestedPairs(cls.children.map(solve(target)))
(mirroredType, elems)

val mirrorType =
Expand Down
43 changes: 43 additions & 0 deletions tests/pos/i13493.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import deriving.Mirror

sealed trait Box[T]
object Box

case class Child[T](t: T) extends Box[T]

object MirrorK1:
type Of[F[_]] = Mirror { type MirroredType[A] = F[A] }

def testSums =

val foo = summon[Mirror.Of[Option[Int] | Option[String]]]
summon[foo.MirroredElemTypes =:= (None.type, Some[Int] | Some[String])]

val bar = summon[Mirror.Of[Box[Int] | Box[String]]]
summon[bar.MirroredElemTypes =:= ((Child[Int] | Child[String]) *: EmptyTuple)]

val qux = summon[Mirror.Of[Option[Int | String]]]
summon[qux.MirroredElemTypes =:= (None.type, Some[Int | String])]

val bip = summon[Mirror.Of[Box[Int | String]]]
summon[bip.MirroredElemTypes =:= (Child[Int | String] *: EmptyTuple)]

val bap = summon[MirrorK1.Of[[X] =>> Box[X] | Box[Int] | Box[String]]]
summon[bap.MirroredElemTypes[Boolean] =:= ((Child[Boolean] | Child[Int] | Child[String]) *: EmptyTuple)]


def testProducts =
val foo = summon[Mirror.Of[Some[Int] | Some[String]]]
summon[foo.MirroredElemTypes =:= ((Int | String) *: EmptyTuple)]

val bar = summon[Mirror.Of[Child[Int] | Child[String]]]
summon[bar.MirroredElemTypes =:= ((Int | String) *: EmptyTuple)]

val qux = summon[Mirror.Of[Some[Int | String]]]
summon[foo.MirroredElemTypes =:= ((Int | String) *: EmptyTuple)]

val bip = summon[Mirror.Of[Child[Int | String]]]
summon[bip.MirroredElemTypes =:= ((Int | String) *: EmptyTuple)]

val bap = summon[MirrorK1.Of[[X] =>> Child[X] | Child[Int] | Child[String]]]
summon[bap.MirroredElemTypes[Boolean] =:= ((Boolean | Int | String) *: EmptyTuple)]

0 comments on commit 65a0d8c

Please sign in to comment.