From ee5a82f65f42236d142f0e4b588e4522d03af616 Mon Sep 17 00:00:00 2001 From: Dale Wijnand Date: Wed, 10 Mar 2021 15:42:15 +0000 Subject: [PATCH] Synthesise Mirror.Sum for nested hierarchies --- .../dotty/tools/dotc/transform/SymUtils.scala | 8 +- .../dotc/transform/SyntheticMembers.scala | 2 +- .../dotty/tools/dotc/typer/Synthesizer.scala | 22 +-- .../dotty/tools/dotc/CompilationTests.scala | 1 + .../fatal-warnings/i11050.scala | 141 ++++++++++++++++++ 5 files changed, 160 insertions(+), 14 deletions(-) create mode 100644 tests/run-custom-args/fatal-warnings/i11050.scala diff --git a/compiler/src/dotty/tools/dotc/transform/SymUtils.scala b/compiler/src/dotty/tools/dotc/transform/SymUtils.scala index 4df1d75e93fa..691425bfb713 100644 --- a/compiler/src/dotty/tools/dotc/transform/SymUtils.scala +++ b/compiler/src/dotty/tools/dotc/transform/SymUtils.scala @@ -95,7 +95,7 @@ object SymUtils: * - none of its children are anonymous classes * - all of its children are addressable through a path from the parent class * and also the location of the generated mirror. - * - all of its children are generic products or singletons + * - all of its children are generic products, singletons, or generic sums themselves. */ def whyNotGenericSum(declScope: Symbol)(using Context): String = if (!self.is(Sealed)) @@ -118,7 +118,11 @@ object SymUtils: else { val s = child.whyNotGenericProduct if (s.isEmpty) s - else i"its child $child is not a generic product because $s" + else if (child.is(Sealed)) { + val s = child.whyNotGenericSum(if child.useCompanionAsMirror then child.linkedClass else ctx.owner) + if (s.isEmpty) s + else i"its child $child is not a generic sum because $s" + } else i"its child $child is not a generic product because $s" } } if (children.isEmpty) "it does not have subclasses" diff --git a/compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala b/compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala index 817be9cba633..60125c5a2058 100644 --- a/compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala +++ b/compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala @@ -525,7 +525,7 @@ class SyntheticMembers(thisPhase: DenotTransformer) { val pat = Typed(untpd.Ident(nme.WILDCARD).withType(patType), TypeTree(patType)) CaseDef(pat, EmptyTree, Literal(Constant(idx))) } - Match(param, cases) + Match(param.annotated(New(defn.UncheckedAnnot.typeRef, Nil)), cases) } /** - If `impl` is the companion of a generic sum, add `deriving.Mirror.Sum` parent diff --git a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala index f66000c714e5..a7ab6d3dded9 100644 --- a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala @@ -285,12 +285,12 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context): val elemLabels = cls.children.map(c => ConstantType(Constant(c.name.toString))) def solve(sym: Symbol): Type = sym match - case caseClass: ClassSymbol => - assert(caseClass.is(Case)) - if caseClass.is(Module) then - caseClass.sourceModule.termRef + case childClass: ClassSymbol => + assert(childClass.isOneOf(Case | Sealed)) + if childClass.is(Module) then + childClass.sourceModule.termRef else - caseClass.primaryConstructor.info match + childClass.primaryConstructor.info match case info: PolyType => // Compute the the full child type by solving the subtype constraint // `C[X1, ..., Xn] <: P`, where @@ -307,13 +307,13 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context): case tp => tp resType <:< target val tparams = poly.paramRefs - val variances = caseClass.typeParams.map(_.paramVarianceSign) + val variances = childClass.typeParams.map(_.paramVarianceSign) val instanceTypes = tparams.lazyZip(variances).map((tparam, variance) => TypeComparer.instanceType(tparam, fromBelow = variance < 0)) resType.substParams(poly, instanceTypes) - instantiate(using ctx.fresh.setExploreTyperState().setOwner(caseClass)) + instantiate(using ctx.fresh.setExploreTyperState().setOwner(childClass)) case _ => - caseClass.typeRef + childClass.typeRef case child => child.termRef end solve @@ -328,9 +328,9 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context): (mirroredType, elems) val mirrorType = - mirrorCore(defn.Mirror_SumClass, monoType, mirroredType, cls.name, formal) - .refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType)) - .refinedWith(tpnme.MirroredElemLabels, TypeAlias(TypeOps.nestedPairs(elemLabels))) + mirrorCore(defn.Mirror_SumClass, monoType, mirroredType, cls.name, formal) + .refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType)) + .refinedWith(tpnme.MirroredElemLabels, TypeAlias(TypeOps.nestedPairs(elemLabels))) val mirrorRef = if useCompanion then companionPath(mirroredType, span) else anonymousMirror(monoType, ExtendsSumMirror, span) diff --git a/compiler/test/dotty/tools/dotc/CompilationTests.scala b/compiler/test/dotty/tools/dotc/CompilationTests.scala index a88f94565e32..d896afa3ea28 100644 --- a/compiler/test/dotty/tools/dotc/CompilationTests.scala +++ b/compiler/test/dotty/tools/dotc/CompilationTests.scala @@ -197,6 +197,7 @@ class CompilationTests { compileFile("tests/run-custom-args/no-useless-forwarders.scala", defaultOptions and "-Xmixin-force-forwarders:false"), compileFile("tests/run-custom-args/defaults-serizaliable-no-forwarders.scala", defaultOptions and "-Xmixin-force-forwarders:false"), compileFilesInDir("tests/run-custom-args/erased", defaultOptions.and("-language:experimental.erasedDefinitions")), + compileFilesInDir("tests/run-custom-args/fatal-warnings", defaultOptions.and("-Xfatal-warnings")), compileFilesInDir("tests/run-deep-subtype", allowDeepSubtypes), compileFilesInDir("tests/run", defaultOptions.and("-Ysafe-init")) ).checkRuns() diff --git a/tests/run-custom-args/fatal-warnings/i11050.scala b/tests/run-custom-args/fatal-warnings/i11050.scala new file mode 100644 index 000000000000..f0bdd23031fa --- /dev/null +++ b/tests/run-custom-args/fatal-warnings/i11050.scala @@ -0,0 +1,141 @@ +import scala.compiletime.* +import scala.deriving.* + +object OriginalReport: + sealed trait TreeValue + sealed trait SubLevel extends TreeValue + case class Leaf1(value: String) extends TreeValue + case class Leaf2(value: Int) extends SubLevel + +// Variants from the initial failure in akka.event.LogEvent +object FromAkkaCB: + sealed trait A + sealed trait B extends A + sealed trait C extends A + case class D() extends B, C + case class E() extends C, B + +object FromAkkaCB2: + sealed trait A + sealed trait N extends A + case class B() extends A + case class C() extends A, N + +object FromAkkaCB3: + sealed trait A + case class B() extends A + case class C() extends A + class D extends C // ignored pattern: class extending a case class + +object NoUnreachableWarnings: + sealed trait Top + object Top + + final case class MiddleA() extends Top with Bottom + final case class MiddleB() extends Top with Bottom + final case class MiddleC() extends Top with Bottom + + sealed trait Bottom extends Top + +object FromAkkaCB4: + sealed trait LogEvent + object LogEvent + case class Error() extends LogEvent + class Error2() extends Error() with LogEventWithMarker // ignored pattern + case class Warning() extends LogEvent + sealed trait LogEventWithMarker extends LogEvent // must be defined late + +object FromAkkaCB4simpler: + sealed trait LogEvent + object LogEvent + case class Error() extends LogEvent + class Error2() extends LogEventWithMarker // not a case class + case class Warning() extends LogEvent + sealed trait LogEventWithMarker extends LogEvent + +object Test: + def main(args: Array[String]): Unit = + testOriginalReport() + testFromAkkaCB() + testFromAkkaCB2() + end main + + def testOriginalReport() = + import OriginalReport._ + val m = summon[Mirror.SumOf[TreeValue]] + given Show[TreeValue] = Show.derived[TreeValue] + val leaf1 = Leaf1("1") + val leaf2 = Leaf2(2) + + assertEq(List(leaf1, leaf2).map(m.ordinal), List(1, 0)) + assertShow[TreeValue](leaf1, "[1] Leaf1(value = \"1\")") + assertShow[TreeValue](leaf2, "[0] [0] Leaf2(value = 2)") + end testOriginalReport + + def testFromAkkaCB() = + import FromAkkaCB._ + val m = summon[Mirror.SumOf[A]] + given Show[A] = Show.derived[A] + val d = D() + val e = E() + + assertEq(List(d, e).map(m.ordinal), List(0, 0)) + assertShow[A](d, "[0] [0] D") + assertShow[A](e, "[0] [1] E") + end testFromAkkaCB + + def testFromAkkaCB2() = + import FromAkkaCB2._ + val m = summon[Mirror.SumOf[A]] + val n = summon[Mirror.SumOf[N]] + given Show[A] = Show.derived[A] + val b = B() + val c = C() + + assertEq(List(b, c).map(m.ordinal), List(1, 0)) + assertShow[A](b, "[1] B") + assertShow[A](c, "[0] [0] C") + end testFromAkkaCB2 + + def assertEq[A](obt: A, exp: A) = assert(obt == exp, s"$obt != $exp (obtained != expected)") + def assertShow[A: Show](x: A, s: String) = assertEq(Show.show(x), s) +end Test + +trait Show[-T]: + def show(x: T): String + +object Show: + given Show[Int] with { def show(x: Int) = s"$x" } + given Show[Char] with { def show(x: Char) = s"'$x'" } + given Show[String] with { def show(x: String) = s"$"$x$"" } + + inline def show[T](x: T): String = summonInline[Show[T]].show(x) + + transparent inline def derived[T](implicit ev: Mirror.Of[T]): Show[T] = new { + def show(x: T): String = inline ev match { + case m: Mirror.ProductOf[T] => showProduct(x.asInstanceOf[Product], m) + case m: Mirror.SumOf[T] => showCases[m.MirroredElemTypes](0)(x, m.ordinal(x)) + } + } + + transparent inline def showProduct[T](x: Product, m: Mirror.ProductOf[T]): String = + constValue[m.MirroredLabel] + showElems[m.MirroredElemTypes, m.MirroredElemLabels](0, Nil)(x) + + transparent inline def showElems[Elems <: Tuple, Labels <: Tuple](n: Int, elems: List[String])(x: Product): String = + inline (erasedValue[Labels], erasedValue[Elems]) match { + case _: (label *: labels, elem *: elems) => + val value = show(x.productElement(n).asInstanceOf[elem]) + showElems[elems, labels](n + 1, s"${constValue[label]} = $value" :: elems)(x) + case _: (EmptyTuple, EmptyTuple) => + if elems.isEmpty then "" else elems.mkString(s"(", ", ", ")") + } + + transparent inline def showCases[Alts <: Tuple](n: Int)(x: Any, ord: Int): String = + inline erasedValue[Alts] match { + case _: (alt *: alts) => + if (ord == n) summonFrom { + case m: Mirror.Of[`alt`] => s"[$ord] " + derived[alt](using m).show(x.asInstanceOf[alt]) + } else showCases[alts](n + 1)(x, ord) + case _: EmptyTuple => throw new MatchError(x) + } +end Show