Skip to content

Commit

Permalink
fix 14823: handle unions in companion path
Browse files Browse the repository at this point in the history
  • Loading branch information
bishabosha committed May 5, 2022
1 parent 06a8f22 commit 16ed844
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 24 deletions.
11 changes: 9 additions & 2 deletions compiler/src/dotty/tools/dotc/transform/TypeUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import Contexts._
import Symbols._
import Names.Name

import dotty.tools.dotc.core.Decorators.*

object TypeUtils {
/** A decorator that provides methods on types
* that are needed in the transformer pipeline.
Expand Down Expand Up @@ -84,11 +86,16 @@ object TypeUtils {
/** The TermRef referring to the companion of the underlying class reference
* of this type, while keeping the same prefix.
*/
def companionRef(using Context): TermRef = self match {
def mirrorCompanionRef(using Context): TermRef = self match {
case OrType(tp1, tp2) =>
val r1 = tp1.mirrorCompanionRef
val r2 = tp2.mirrorCompanionRef
assert(r1.symbol == r2.symbol, em"mirrorCompanionRef mismatch for $self: $r1, $r2 did not have the same symbol")
r1
case self @ TypeRef(prefix, _) if self.symbol.isClass =>
prefix.select(self.symbol.companionModule).asInstanceOf[TermRef]
case self: TypeProxy =>
self.underlying.companionRef
self.underlying.mirrorCompanionRef
}

/** Is this type a methodic type that takes implicit parameters (both old and new) at some point? */
Expand Down
55 changes: 33 additions & 22 deletions compiler/src/dotty/tools/dotc/typer/Synthesizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):

/** A path referencing the companion of class type `clsType` */
private def companionPath(clsType: Type, span: Span)(using Context) =
val ref = pathFor(clsType.companionRef)
val ref = pathFor(clsType.mirrorCompanionRef)
assert(ref.symbol.is(Module) && (clsType.classSymbol.is(ModuleClass) || (ref.symbol.companionClass == clsType.classSymbol)))
ref.withSpan(span)

Expand All @@ -275,6 +275,35 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
monoMap(mirroredType.resultType)

private def productMirror(mirroredType: Type, formal: Type, span: Span)(using Context): Tree =

/** do all parts match the class symbol? */
def acceptable(tp: Type, cls: Symbol): Boolean = tp match
case tp: TypeProxy => acceptable(tp.underlying, cls)
case OrType(tp1, tp2) => acceptable(tp1, cls) && acceptable(tp2, cls)
case _ => tp.classSymbol eq cls

def makeProductMirror(cls: Symbol): Tree =
val accessors = cls.caseAccessors.filterNot(_.isAllOf(PrivateLocal))
val elemLabels = accessors.map(acc => ConstantType(Constant(acc.name.toString)))
val nestedPairs = TypeOps.nestedPairs(accessors.map(mirroredType.resultType.memberInfo(_).widenExpr))
val (monoType, elemsType) = mirroredType match
case mirroredType: HKTypeLambda =>
(mkMirroredMonoType(mirroredType), mirroredType.derivedLambdaType(resType = nestedPairs))
case _ =>
(mirroredType, nestedPairs)
val elemsLabels = TypeOps.nestedPairs(elemLabels)
checkRefinement(formal, tpnme.MirroredElemTypes, elemsType, span)
checkRefinement(formal, tpnme.MirroredElemLabels, elemsLabels, span)
val mirrorType =
mirrorCore(defn.Mirror_ProductClass, monoType, mirroredType, cls.name, formal)
.refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType))
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(elemsLabels))
val mirrorRef =
if (cls.is(Scala2x) || cls.linkedClass.is(Case)) anonymousMirror(monoType, ExtendsProductMirror, span)
else companionPath(mirroredType, span)
mirrorRef.cast(mirrorType)
end makeProductMirror

mirroredType match
case AndType(tp1, tp2) =>
productMirror(tp1, formal, span).orElse(productMirror(tp2, formal, span))
Expand All @@ -289,28 +318,10 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
else
val mirrorType = mirrorCore(defn.Mirror_SingletonClass, mirroredType, mirroredType, module.name, formal)
modulePath.cast(mirrorType)
else if mirroredType.classSymbol.isGenericProduct then
else
val cls = mirroredType.classSymbol
val accessors = cls.caseAccessors.filterNot(_.isAllOf(PrivateLocal))
val elemLabels = accessors.map(acc => ConstantType(Constant(acc.name.toString)))
val nestedPairs = TypeOps.nestedPairs(accessors.map(mirroredType.resultType.memberInfo(_).widenExpr))
val (monoType, elemsType) = mirroredType match
case mirroredType: HKTypeLambda =>
(mkMirroredMonoType(mirroredType), mirroredType.derivedLambdaType(resType = nestedPairs))
case _ =>
(mirroredType, nestedPairs)
val elemsLabels = TypeOps.nestedPairs(elemLabels)
checkRefinement(formal, tpnme.MirroredElemTypes, elemsType, span)
checkRefinement(formal, tpnme.MirroredElemLabels, elemsLabels, span)
val mirrorType =
mirrorCore(defn.Mirror_ProductClass, monoType, mirroredType, cls.name, formal)
.refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType))
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(elemsLabels))
val mirrorRef =
if (cls.is(Scala2x) || cls.linkedClass.is(Case)) anonymousMirror(monoType, ExtendsProductMirror, span)
else companionPath(mirroredType, span)
mirrorRef.cast(mirrorType)
else EmptyTree
if acceptable(mirroredType, cls) && cls.isGenericProduct then makeProductMirror(cls)
else EmptyTree
end productMirror

private def sumMirror(mirroredType: Type, formal: Type, span: Span)(using Context): Tree =
Expand Down
4 changes: 4 additions & 0 deletions tests/neg/i14823.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-- Error: tests/neg/i14823.scala:8:50 ----------------------------------------------------------------------------------
8 |val baz = summon[Mirror.Of[SubA[Int] | SubB[Int]]] // error
| ^
|no given instance of type deriving.Mirror.Of[SubA[Int] | SubB[Int]] was found for parameter x of method summon in object Predef
12 changes: 12 additions & 0 deletions tests/neg/i14823.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import deriving.Mirror

case class Cov[+T]()

class SubA[+T]() extends Cov[T]
class SubB[+T]() extends Cov[T]

val baz = summon[Mirror.Of[SubA[Int] | SubB[Int]]] // error
// ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
// this should fail because:
// 1) SubA and SubB are not individually product types
// 2) SubA and SubB are different classes
18 changes: 18 additions & 0 deletions tests/pos/i14823.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import deriving.Mirror

object MirrorK1:
type Of[F[_]] = Mirror { type MirroredType[A] = F[A] }

sealed trait Box[T]
object Box

case class Child[T]() extends Box[T]

sealed abstract class Foo[T]
object Foo {
case class A[T]() extends Foo[T]
}

val foo = summon[Mirror.Of[Box[Int] | Box[Int]]]
val bar = summon[MirrorK1.Of[[X] =>> Box[Int] | Box[Int]]]
def baz = summon[deriving.Mirror.Of[Foo[String] | Foo[String]]]

0 comments on commit 16ed844

Please sign in to comment.