Skip to content

Commit

Permalink
Synthesise Mirror.Sum for nested hierarchies
Browse files Browse the repository at this point in the history
  • Loading branch information
dwijnand committed Jun 2, 2021
1 parent 8f3fdf5 commit 3665eba
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 13 deletions.
8 changes: 6 additions & 2 deletions compiler/src/dotty/tools/dotc/transform/SymUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ object SymUtils:
* - none of its children are anonymous classes
* - all of its children are addressable through a path from the parent class
* and also the location of the generated mirror.
* - all of its children are generic products or singletons
* - all of its children are generic products, singletons, or generic sums themselves.
*/
def whyNotGenericSum(declScope: Symbol)(using Context): String =
if (!self.is(Sealed))
Expand All @@ -116,7 +116,11 @@ object SymUtils:
else {
val s = child.whyNotGenericProduct
if (s.isEmpty) s
else i"its child $child is not a generic product because $s"
else if (child.is(Sealed)) {
val s = child.whyNotGenericSum(if child.useCompanionAsMirror then child.linkedClass else child.owner)
if (s.isEmpty) s
else i"its child $child is not a generic sum because $s"
} else i"its child $child is not a generic product because $s"
}
}
if (children.isEmpty) "it does not have subclasses"
Expand Down
22 changes: 11 additions & 11 deletions compiler/src/dotty/tools/dotc/typer/Synthesizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -288,12 +288,12 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
val elemLabels = cls.children.map(c => ConstantType(Constant(c.name.toString)))

def solve(sym: Symbol): Type = sym match
case caseClass: ClassSymbol =>
assert(caseClass.is(Case))
if caseClass.is(Module) then
caseClass.sourceModule.termRef
case childClass: ClassSymbol =>
assert(childClass.isOneOf(Case | Sealed))
if childClass.is(Module) then
childClass.sourceModule.termRef
else
caseClass.primaryConstructor.info match
childClass.primaryConstructor.info match
case info: PolyType =>
// Compute the the full child type by solving the subtype constraint
// `C[X1, ..., Xn] <: P`, where
Expand All @@ -310,13 +310,13 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
case tp => tp
resType <:< target
val tparams = poly.paramRefs
val variances = caseClass.typeParams.map(_.paramVarianceSign)
val variances = childClass.typeParams.map(_.paramVarianceSign)
val instanceTypes = tparams.lazyZip(variances).map((tparam, variance) =>
TypeComparer.instanceType(tparam, fromBelow = variance < 0))
resType.substParams(poly, instanceTypes)
instantiate(using ctx.fresh.setExploreTyperState().setOwner(caseClass))
instantiate(using ctx.fresh.setExploreTyperState().setOwner(childClass))
case _ =>
caseClass.typeRef
childClass.typeRef
case child => child.termRef
end solve

Expand All @@ -331,9 +331,9 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
(mirroredType, elems)

val mirrorType =
mirrorCore(defn.Mirror_SumClass, monoType, mirroredType, cls.name, formal)
.refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType))
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(TypeOps.nestedPairs(elemLabels)))
mirrorCore(defn.Mirror_SumClass, monoType, mirroredType, cls.name, formal)
.refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType))
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(TypeOps.nestedPairs(elemLabels)))
val mirrorRef =
if useCompanion then companionPath(mirroredType, span)
else anonymousMirror(monoType, ExtendsSumMirror, span)
Expand Down
129 changes: 129 additions & 0 deletions tests/run/i11050.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import scala.compiletime.*
import scala.deriving.*

object OriginalReport:
sealed trait TreeValue
sealed trait SubLevel extends TreeValue
case class Leaf1(value: String) extends TreeValue
case class Leaf2(value: Int) extends SubLevel

// Variants from the initial failure in akka.event.LogEvent
object FromAkkaCB:
sealed trait A
sealed trait B extends A
sealed trait C extends A
case class D() extends B, C
case class E() extends C, B

object FromAkkaCB2:
sealed trait A
sealed trait N extends A
case class B() extends A
case class C() extends A, N

object FromAkkaCB3:
sealed trait A
case class B() extends A
case class C() extends A
class D extends C // ignored pattern: class extending a case class

object FromAkkaCB4:
sealed trait A
sealed trait N extends A
case class B() extends A
case class C() extends A
class D extends C, N // ignored

object FromAkkaCB5:
sealed trait A
sealed trait N extends A
case class B() extends A
case class C() extends A
class D extends N // ignored

object Test:
def main(args: Array[String]): Unit =
testOriginalReport()
testFromAkkaCB()
testFromAkkaCB2()
end main

def testOriginalReport() =
import OriginalReport._
val m = summon[Mirror.SumOf[TreeValue]]
given Show[TreeValue] = Show.derived[TreeValue]
val leaf1 = Leaf1("1")
val leaf2 = Leaf2(2)

assertEq(List(leaf1, leaf2).map(m.ordinal), List(1, 0))
assertShow[TreeValue](leaf1, "[1] Leaf1(value = \"1\")")
assertShow[TreeValue](leaf2, "[0] [0] Leaf2(value = 2)")
end testOriginalReport

def testFromAkkaCB() =
import FromAkkaCB._
val m = summon[Mirror.SumOf[A]]
given Show[A] = Show.derived[A]
val d = D()
val e = E()

assertEq(List(d, e).map(m.ordinal), List(0, 0))
assertShow[A](d, "[0] [0] D")
assertShow[A](e, "[0] [1] E")
end testFromAkkaCB

def testFromAkkaCB2() =
import FromAkkaCB2._
val m = summon[Mirror.SumOf[A]]
val n = summon[Mirror.SumOf[N]]
given Show[A] = Show.derived[A]
val b = B()
val c = C()

assertEq(List(b, c).map(m.ordinal), List(1, 0))
assertShow[A](b, "[1] B")
assertShow[A](c, "[0] [0] C")
end testFromAkkaCB2

def assertEq[A](obt: A, exp: A) = assert(obt == exp, s"$obt != $exp (obtained != expected)")
def assertShow[A: Show](x: A, s: String) = assertEq(Show.show(x), s)
end Test

trait Show[-T]:
def show(x: T): String

object Show:
given Show[Int] with { def show(x: Int) = s"$x" }
given Show[Char] with { def show(x: Char) = s"'$x'" }
given Show[String] with { def show(x: String) = s"$"$x$"" }

inline def show[T](x: T): String = summonInline[Show[T]].show(x)

transparent inline def derived[T](implicit ev: Mirror.Of[T]): Show[T] = new {
def show(x: T): String = inline ev match {
case m: Mirror.ProductOf[T] => showProduct(x.asInstanceOf[Product], m)
case m: Mirror.SumOf[T] => showCases[m.MirroredElemTypes](0)(x, m.ordinal(x))
}
}

transparent inline def showProduct[T](x: Product, m: Mirror.ProductOf[T]): String =
constValue[m.MirroredLabel] + showElems[m.MirroredElemTypes, m.MirroredElemLabels](0, Nil)(x)

transparent inline def showElems[Elems <: Tuple, Labels <: Tuple](n: Int, elems: List[String])(x: Product): String =
inline (erasedValue[Labels], erasedValue[Elems]) match {
case _: (label *: labels, elem *: elems) =>
val value = show(x.productElement(n).asInstanceOf[elem])
showElems[elems, labels](n + 1, s"${constValue[label]} = $value" :: elems)(x)
case _: (EmptyTuple, EmptyTuple) =>
if elems.isEmpty then "" else elems.mkString(s"(", ", ", ")")
}

transparent inline def showCases[Alts <: Tuple](n: Int)(x: Any, ord: Int): String =
inline erasedValue[Alts] match {
case _: (alt *: alts) =>
if (ord == n) summonFrom {
case m: Mirror.Of[`alt`] => s"[$ord] " + derived[alt](using m).show(x.asInstanceOf[alt])
} else showCases[alts](n + 1)(x, ord)
case _: EmptyTuple => throw new MatchError(x)
}
end Show

0 comments on commit 3665eba

Please sign in to comment.