Skip to content

Commit

Permalink
Synthesise Mirror.Sum for nested hierarchies
Browse files Browse the repository at this point in the history
  • Loading branch information
dwijnand committed Mar 29, 2021
1 parent 5bb0e92 commit 03a3115
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 8 deletions.
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/core/SymDenotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1612,6 +1612,9 @@ object SymDenotations {

annotations.collect { case Annotation.Child(child) => child }.reverse
end children

def subclasses(using Context): List[Symbol] =
children.flatMap(c => if c.is(Sealed) then c.children else List(c)).sortBy(_.span.start)
}

/** The contents of a class definition during a period
Expand Down
9 changes: 7 additions & 2 deletions compiler/src/dotty/tools/dotc/transform/SymUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ object SymUtils:
* - none of its children are anonymous classes
* - all of its children are addressable through a path from the parent class
* and also the location of the generated mirror.
* - all of its children are generic products or singletons
* - all of its children are generic products, singletons, or generic sums
*/
def whyNotGenericSum(declScope: Symbol)(using Context): String =
if (!self.is(Sealed))
Expand All @@ -113,7 +113,12 @@ object SymUtils:
if (child == self) "it has anonymous or inaccessible subclasses"
else if (!isAccessible(child.owner)) i"its child $child is not accessible"
else if (!child.isClass) ""
else {
else if (child.isGenericProduct) ""
else if (child.is(Sealed)) {
val s = child.whyNotGenericSum(declScope)
if (s.isEmpty) s
else i"its child $child is not a generic sum because $s"
} else {
val s = child.whyNotGenericProduct
if (s.isEmpty) s
else i"its child $child is not a generic product because $s"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
if (cls.is(Enum)) param.select(nme.ordinal).ensureApplied
else {
val cases =
for ((child, idx) <- cls.children.zipWithIndex) yield {
for ((child, idx) <- cls.subclasses.zipWithIndex) yield {
val patType = if (child.isTerm) child.reachableTermRef else child.reachableRawTypeRef
val pat = Typed(untpd.Ident(nme.WILDCARD).withType(patType), TypeTree(patType))
CaseDef(pat, EmptyTree, Literal(Constant(idx)))
Expand Down
9 changes: 4 additions & 5 deletions compiler/src/dotty/tools/dotc/typer/Synthesizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -275,11 +275,10 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
val useCompanion = cls.useCompanionAsMirror

if cls.isGenericSum(if useCompanion then cls.linkedClass else ctx.owner) then
val elemLabels = cls.children.map(c => ConstantType(Constant(c.name.toString)))
val elemLabels = cls.subclasses.map(c => ConstantType(Constant(c.name.toString)))

def solve(sym: Symbol): Type = sym match
case caseClass: ClassSymbol =>
assert(caseClass.is(Case))
case caseClass: ClassSymbol if caseClass.is(Case) =>
if caseClass.is(Module) then
caseClass.sourceModule.termRef
else
Expand Down Expand Up @@ -313,11 +312,11 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
val (monoType, elemsType) = mirroredType match
case mirroredType: HKTypeLambda =>
val elems = mirroredType.derivedLambdaType(
resType = TypeOps.nestedPairs(cls.children.map(solve))
resType = TypeOps.nestedPairs(cls.subclasses.map(solve))
)
(mkMirroredMonoType(mirroredType), elems)
case _ =>
val elems = TypeOps.nestedPairs(cls.children.map(solve))
val elems = TypeOps.nestedPairs(cls.subclasses.map(solve))
(mirroredType, elems)

val mirrorType =
Expand Down
68 changes: 68 additions & 0 deletions tests/run/i11050.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import scala.compiletime.*
import scala.deriving.*

sealed trait TreeValue

sealed trait SubLevel extends TreeValue

case class Leaf1(value: String) extends TreeValue
case class Leaf2(value: Int) extends SubLevel
case class Leaf3(value: Char) extends TreeValue

object Test:
val m = summon[Mirror.SumOf[TreeValue]]
given Show[TreeValue] = Show.derived[TreeValue]

def main(args: Array[String]) =
val leaf1 = Leaf1("1")
val leaf2 = Leaf2(2)
val leaf3 = Leaf3('3')

assertEq(List(leaf1, leaf2, leaf3).map(m.ordinal), List(0, 1, 2))
assertShow[TreeValue](leaf1, "[0] Leaf1(value = \"1\")")
assertShow[TreeValue](leaf2, "[1] Leaf2(value = 2)")
assertShow[TreeValue](leaf3, "[2] Leaf3(value = '3')")
end main

def assertEq[A](obt: A, exp: A) = assert(obt == exp, s"Expected $obt == $exp")
def assertShow[A: Show](x: A, s: String) = assertEq(Show.show(x), s)
end Test

trait Show[-T]:
def show(x: T): String

object Show:
given Show[Int] with { def show(x: Int) = s"$x" }
given Show[Char] with { def show(x: Char) = s"'$x'" }
given Show[String] with { def show(x: String) = s"$"$x$"" }

inline def show[T](x: T): String = summonInline[Show[T]].show(x)

transparent inline def derived[T](implicit ev: Mirror.Of[T]): Show[T] = new {
def show(x: T): String = inline ev match {
case m: Mirror.ProductOf[T] => showProduct(x.asInstanceOf[Product], m)
case m: Mirror.SumOf[T] => showCases[m.MirroredElemTypes](0)(x, m.ordinal(x))
}
}

inline def showProduct[T](x: Product, m: Mirror.ProductOf[T]): String =
constValue[m.MirroredLabel] + showElems[m.MirroredElemTypes, m.MirroredElemLabels](0, Nil)(x)

inline def showElems[Elems <: Tuple, Labels <: Tuple](n: Int, elems: List[String])(x: Product): String =
inline (erasedValue[Labels], erasedValue[Elems]) match {
case _: (label *: labels, elem *: elems) =>
val value = show(x.productElement(n).asInstanceOf[elem])
showElems[elems, labels](n + 1, s"${constValue[label]} = $value" :: elems)(x)
case _: (EmptyTuple, EmptyTuple) =>
if elems.isEmpty then "" else elems.mkString(s"(", ", ", ")")
}

transparent inline def showCases[Alts <: Tuple](n: Int)(x: Any, ord: Int): String =
inline erasedValue[Alts] match {
case _: (alt *: alts) =>
if (ord == n) summonFrom {
case m: Mirror.Of[`alt`] => s"[$ord] " + derived[alt](using m).show(x.asInstanceOf[alt])
} else showCases[alts](n + 1)(x, ord)
case _: EmptyTuple => throw new MatchError(x)
}
end Show

0 comments on commit 03a3115

Please sign in to comment.