From 16ed84475d3e187c89b1aa5b66fb6281ec9eb54d Mon Sep 17 00:00:00 2001 From: Jamie Thompson Date: Fri, 22 Apr 2022 10:31:28 +0200 Subject: [PATCH] fix 14823: handle unions in companion path --- .../tools/dotc/transform/TypeUtils.scala | 11 +++- .../dotty/tools/dotc/typer/Synthesizer.scala | 55 +++++++++++-------- tests/neg/i14823.check | 4 ++ tests/neg/i14823.scala | 12 ++++ tests/pos/i14823.scala | 18 ++++++ 5 files changed, 76 insertions(+), 24 deletions(-) create mode 100644 tests/neg/i14823.check create mode 100644 tests/neg/i14823.scala create mode 100644 tests/pos/i14823.scala diff --git a/compiler/src/dotty/tools/dotc/transform/TypeUtils.scala b/compiler/src/dotty/tools/dotc/transform/TypeUtils.scala index 7a3da6ad4bde..709635630254 100644 --- a/compiler/src/dotty/tools/dotc/transform/TypeUtils.scala +++ b/compiler/src/dotty/tools/dotc/transform/TypeUtils.scala @@ -9,6 +9,8 @@ import Contexts._ import Symbols._ import Names.Name +import dotty.tools.dotc.core.Decorators.* + object TypeUtils { /** A decorator that provides methods on types * that are needed in the transformer pipeline. @@ -84,11 +86,16 @@ object TypeUtils { /** The TermRef referring to the companion of the underlying class reference * of this type, while keeping the same prefix. */ - def companionRef(using Context): TermRef = self match { + def mirrorCompanionRef(using Context): TermRef = self match { + case OrType(tp1, tp2) => + val r1 = tp1.mirrorCompanionRef + val r2 = tp2.mirrorCompanionRef + assert(r1.symbol == r2.symbol, em"mirrorCompanionRef mismatch for $self: $r1, $r2 did not have the same symbol") + r1 case self @ TypeRef(prefix, _) if self.symbol.isClass => prefix.select(self.symbol.companionModule).asInstanceOf[TermRef] case self: TypeProxy => - self.underlying.companionRef + self.underlying.mirrorCompanionRef } /** Is this type a methodic type that takes implicit parameters (both old and new) at some point? */ diff --git a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala index 03554cbabea6..b2afc8ba66c8 100644 --- a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala @@ -249,7 +249,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context): /** A path referencing the companion of class type `clsType` */ private def companionPath(clsType: Type, span: Span)(using Context) = - val ref = pathFor(clsType.companionRef) + val ref = pathFor(clsType.mirrorCompanionRef) assert(ref.symbol.is(Module) && (clsType.classSymbol.is(ModuleClass) || (ref.symbol.companionClass == clsType.classSymbol))) ref.withSpan(span) @@ -275,6 +275,35 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context): monoMap(mirroredType.resultType) private def productMirror(mirroredType: Type, formal: Type, span: Span)(using Context): Tree = + + /** do all parts match the class symbol? */ + def acceptable(tp: Type, cls: Symbol): Boolean = tp match + case tp: TypeProxy => acceptable(tp.underlying, cls) + case OrType(tp1, tp2) => acceptable(tp1, cls) && acceptable(tp2, cls) + case _ => tp.classSymbol eq cls + + def makeProductMirror(cls: Symbol): Tree = + val accessors = cls.caseAccessors.filterNot(_.isAllOf(PrivateLocal)) + val elemLabels = accessors.map(acc => ConstantType(Constant(acc.name.toString))) + val nestedPairs = TypeOps.nestedPairs(accessors.map(mirroredType.resultType.memberInfo(_).widenExpr)) + val (monoType, elemsType) = mirroredType match + case mirroredType: HKTypeLambda => + (mkMirroredMonoType(mirroredType), mirroredType.derivedLambdaType(resType = nestedPairs)) + case _ => + (mirroredType, nestedPairs) + val elemsLabels = TypeOps.nestedPairs(elemLabels) + checkRefinement(formal, tpnme.MirroredElemTypes, elemsType, span) + checkRefinement(formal, tpnme.MirroredElemLabels, elemsLabels, span) + val mirrorType = + mirrorCore(defn.Mirror_ProductClass, monoType, mirroredType, cls.name, formal) + .refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType)) + .refinedWith(tpnme.MirroredElemLabels, TypeAlias(elemsLabels)) + val mirrorRef = + if (cls.is(Scala2x) || cls.linkedClass.is(Case)) anonymousMirror(monoType, ExtendsProductMirror, span) + else companionPath(mirroredType, span) + mirrorRef.cast(mirrorType) + end makeProductMirror + mirroredType match case AndType(tp1, tp2) => productMirror(tp1, formal, span).orElse(productMirror(tp2, formal, span)) @@ -289,28 +318,10 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context): else val mirrorType = mirrorCore(defn.Mirror_SingletonClass, mirroredType, mirroredType, module.name, formal) modulePath.cast(mirrorType) - else if mirroredType.classSymbol.isGenericProduct then + else val cls = mirroredType.classSymbol - val accessors = cls.caseAccessors.filterNot(_.isAllOf(PrivateLocal)) - val elemLabels = accessors.map(acc => ConstantType(Constant(acc.name.toString))) - val nestedPairs = TypeOps.nestedPairs(accessors.map(mirroredType.resultType.memberInfo(_).widenExpr)) - val (monoType, elemsType) = mirroredType match - case mirroredType: HKTypeLambda => - (mkMirroredMonoType(mirroredType), mirroredType.derivedLambdaType(resType = nestedPairs)) - case _ => - (mirroredType, nestedPairs) - val elemsLabels = TypeOps.nestedPairs(elemLabels) - checkRefinement(formal, tpnme.MirroredElemTypes, elemsType, span) - checkRefinement(formal, tpnme.MirroredElemLabels, elemsLabels, span) - val mirrorType = - mirrorCore(defn.Mirror_ProductClass, monoType, mirroredType, cls.name, formal) - .refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType)) - .refinedWith(tpnme.MirroredElemLabels, TypeAlias(elemsLabels)) - val mirrorRef = - if (cls.is(Scala2x) || cls.linkedClass.is(Case)) anonymousMirror(monoType, ExtendsProductMirror, span) - else companionPath(mirroredType, span) - mirrorRef.cast(mirrorType) - else EmptyTree + if acceptable(mirroredType, cls) && cls.isGenericProduct then makeProductMirror(cls) + else EmptyTree end productMirror private def sumMirror(mirroredType: Type, formal: Type, span: Span)(using Context): Tree = diff --git a/tests/neg/i14823.check b/tests/neg/i14823.check new file mode 100644 index 000000000000..3de29b3845b7 --- /dev/null +++ b/tests/neg/i14823.check @@ -0,0 +1,4 @@ +-- Error: tests/neg/i14823.scala:8:50 ---------------------------------------------------------------------------------- +8 |val baz = summon[Mirror.Of[SubA[Int] | SubB[Int]]] // error + | ^ + |no given instance of type deriving.Mirror.Of[SubA[Int] | SubB[Int]] was found for parameter x of method summon in object Predef diff --git a/tests/neg/i14823.scala b/tests/neg/i14823.scala new file mode 100644 index 000000000000..21e836564777 --- /dev/null +++ b/tests/neg/i14823.scala @@ -0,0 +1,12 @@ +import deriving.Mirror + +case class Cov[+T]() + +class SubA[+T]() extends Cov[T] +class SubB[+T]() extends Cov[T] + +val baz = summon[Mirror.Of[SubA[Int] | SubB[Int]]] // error +// ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +// this should fail because: +// 1) SubA and SubB are not individually product types +// 2) SubA and SubB are different classes diff --git a/tests/pos/i14823.scala b/tests/pos/i14823.scala new file mode 100644 index 000000000000..01b8a76095d5 --- /dev/null +++ b/tests/pos/i14823.scala @@ -0,0 +1,18 @@ +import deriving.Mirror + +object MirrorK1: + type Of[F[_]] = Mirror { type MirroredType[A] = F[A] } + +sealed trait Box[T] +object Box + +case class Child[T]() extends Box[T] + +sealed abstract class Foo[T] +object Foo { + case class A[T]() extends Foo[T] +} + +val foo = summon[Mirror.Of[Box[Int] | Box[Int]]] +val bar = summon[MirrorK1.Of[[X] =>> Box[Int] | Box[Int]]] +def baz = summon[deriving.Mirror.Of[Foo[String] | Foo[String]]]