From fe8d39a62be5644340e72ed73e9c464fe059a15e Mon Sep 17 00:00:00 2001 From: odersky Date: Sat, 9 Jul 2022 14:09:41 +0200 Subject: [PATCH] Instantiate more type variables to hard unions Fixes #14770 --- .../tools/dotc/core/ConstraintHandling.scala | 31 ++++++++--- .../dotty/tools/dotc/core/TypeComparer.scala | 54 +++++++++++-------- .../src/dotty/tools/dotc/core/TypeOps.scala | 4 +- .../src/dotty/tools/dotc/core/Types.scala | 4 +- .../src/dotty/tools/dotc/typer/Namer.scala | 2 +- .../dotty/tools/dotc/typer/Synthesizer.scala | 2 +- .../src/dotty/tools/dotc/typer/Typer.scala | 2 +- tests/pos/i14770.scala | 25 +++++++++ 8 files changed, 92 insertions(+), 32 deletions(-) create mode 100644 tests/pos/i14770.scala diff --git a/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala b/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala index 6c9bd1bb6577..60e374bd6474 100644 --- a/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala +++ b/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala @@ -12,6 +12,7 @@ import config.Printers.typr import typer.ProtoTypes.{newTypeVar, representedParamRef} import UnificationDirection.* import NameKinds.AvoidNameKind +import NullOpsDecorator.stripNull /** Methods for adding constraints and solving them. * @@ -525,10 +526,12 @@ trait ConstraintHandling { * At this point we also drop the @Repeated annotation to avoid inferring type arguments with it, * as those could leak the annotation to users (see run/inferred-repeated-result). */ - def widenInferred(inst: Type, bound: Type)(using Context): Type = + def widenInferred(inst: Type, bound: Type, widenUnions: Boolean)(using Context): Type = def widenOr(tp: Type) = - val tpw = tp.widenUnion - if (tpw ne tp) && (tpw <:< bound) then tpw else tp + if widenUnions then + val tpw = tp.widenUnion + if (tpw ne tp) && (tpw <:< bound) then tpw else tp + else tp.hardenUnions def widenSingle(tp: Type) = val tpw = tp.widenSingletons @@ -548,16 +551,32 @@ trait ConstraintHandling { wideInst.dropRepeatedAnnot end widenInferred + extension (tp: Type) private def hardenUnions(using Context): Type = tp.widen match + case tp: AndType => + tp.derivedAndType(tp.tp1.hardenUnions, tp.tp2.hardenUnions) + case tp: RefinedType => + tp.derivedRefinedType(tp.parent.hardenUnions, tp.refinedName, tp.refinedInfo) + case tp: RecType => + tp.rebind(tp.parent.hardenUnions) + case tp: HKTypeLambda => + tp.derivedLambdaType(resType = tp.resType.hardenUnions) + case tp: OrType => + val tp1 = tp.stripNull + if tp1 ne tp then tp.derivedOrType(tp1.hardenUnions, defn.NullType) + else tp.derivedOrType(tp.tp1.hardenUnions, tp.tp2.hardenUnions, soft = false) + case _ => + tp + /** The instance type of `param` in the current constraint (which contains `param`). * If `fromBelow` is true, the instance type is the lub of the parameter's * lower bounds; otherwise it is the glb of its upper bounds. However, * a lower bound instantiation can be a singleton type only if the upper bound * is also a singleton type. */ - def instanceType(param: TypeParamRef, fromBelow: Boolean)(using Context): Type = { + def instanceType(param: TypeParamRef, fromBelow: Boolean, widenUnions: Boolean)(using Context): Type = { val approx = approximation(param, fromBelow).simplified if fromBelow then - val widened = widenInferred(approx, param) + val widened = widenInferred(approx, param, widenUnions) // Widening can add extra constraints, in particular the widened type might // be a type variable which is now instantiated to `param`, and therefore // cannot be used as an instantiation of `param` without creating a loop. @@ -565,7 +584,7 @@ trait ConstraintHandling { // (we do not check for non-toplevel occurences: those should never occur // since `addOneBound` disallows recursive lower bounds). if constraint.occursAtToplevel(param, widened) then - instanceType(param, fromBelow) + instanceType(param, fromBelow, widenUnions) else widened else diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 6d79f377c84e..88702d50f825 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -492,23 +492,35 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling case AndType(tp21, tp22) => constrainRHSVars(tp21) && constrainRHSVars(tp22) case _ => true - widenOK - || joinOK - || (tp1.isSoft || constrainRHSVars(tp2)) && recur(tp11, tp2) && recur(tp12, tp2) - || containsAnd(tp1) - && !joined - && { - joined = true - try inFrozenGadt(recur(tp1.join, tp2)) - finally joined = false - } - // An & on the left side loses information. We compensate by also trying the join. - // This is less ad-hoc than it looks since we produce joins in type inference, - // and then need to check that they are indeed supertypes of the original types - // under -Ycheck. Test case is i7965.scala. - // On the other hand, we could get a combinatorial explosion by applying such joins - // recursively, so we do it only once. See i14870.scala as a test case, which would - // loop for a very long time without the recursion brake. + def hardenTypeVars(tp2: Type): Unit = tp2.dealiasKeepRefiningAnnots match + case tvar: TypeVar if constraint.contains(tvar.origin) => + tvar.widenUnions = false + case tp2: TypeParamRef if constraint.contains(tp2) => + hardenTypeVars(constraint.typeVarOfParam(tp2)) + case tp2: AndOrType => + hardenTypeVars(tp2.tp1) + hardenTypeVars(tp2.tp2) + case _ => + + val res = widenOK + || joinOK + || (tp1.isSoft || constrainRHSVars(tp2)) && recur(tp11, tp2) && recur(tp12, tp2) + || containsAnd(tp1) + && !joined + && { + joined = true + try inFrozenGadt(recur(tp1.join, tp2)) + finally joined = false + } + // An & on the left side loses information. We compensate by also trying the join. + // This is less ad-hoc than it looks since we produce joins in type inference, + // and then need to check that they are indeed supertypes of the original types + // under -Ycheck. Test case is i7965.scala. + // On the other hand, we could get a combinatorial explosion by applying such joins + // recursively, so we do it only once. See i14870.scala as a test case, which would + // loop for a very long time without the recursion brake. + if res then hardenTypeVars(tp2) + res case tp1: MatchType => val reduced = tp1.reduced @@ -2851,8 +2863,8 @@ object TypeComparer { def subtypeCheckInProgress(using Context): Boolean = comparing(_.subtypeCheckInProgress) - def instanceType(param: TypeParamRef, fromBelow: Boolean)(using Context): Type = - comparing(_.instanceType(param, fromBelow)) + def instanceType(param: TypeParamRef, fromBelow: Boolean, widenUnions: Boolean)(using Context): Type = + comparing(_.instanceType(param, fromBelow, widenUnions)) def approximation(param: TypeParamRef, fromBelow: Boolean)(using Context): Type = comparing(_.approximation(param, fromBelow)) @@ -2872,8 +2884,8 @@ object TypeComparer { def addToConstraint(tl: TypeLambda, tvars: List[TypeVar])(using Context): Boolean = comparing(_.addToConstraint(tl, tvars)) - def widenInferred(inst: Type, bound: Type)(using Context): Type = - comparing(_.widenInferred(inst, bound)) + def widenInferred(inst: Type, bound: Type, widenUnions: Boolean)(using Context): Type = + comparing(_.widenInferred(inst, bound, widenUnions)) def dropTransparentTraits(tp: Type, bound: Type)(using Context): Type = comparing(_.dropTransparentTraits(tp, bound)) diff --git a/compiler/src/dotty/tools/dotc/core/TypeOps.scala b/compiler/src/dotty/tools/dotc/core/TypeOps.scala index d12ecd8cd17f..509675947311 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeOps.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeOps.scala @@ -517,7 +517,9 @@ object TypeOps: override def apply(tp: Type): Type = tp match case tp: TypeVar if mapCtx.typerState.constraint.contains(tp) => val lo = TypeComparer.instanceType( - tp.origin, fromBelow = variance > 0 || variance == 0 && tp.hasLowerBound)(using mapCtx) + tp.origin, + fromBelow = variance > 0 || variance == 0 && tp.hasLowerBound, + widenUnions = tp.widenUnions)(using mapCtx) val lo1 = apply(lo) if (lo1 ne lo) lo1 else tp case _ => diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index 2b51d64a2f27..b4d3341a117a 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -4507,6 +4507,8 @@ object Types { final class TypeVar private(initOrigin: TypeParamRef, creatorState: TyperState | Null, val nestingLevel: Int) extends CachedProxyType with ValueType { private var currentOrigin = initOrigin + var widenUnions = true + def origin: TypeParamRef = currentOrigin /** Set origin to new parameter. Called if we merge two conflicting constraints. @@ -4569,7 +4571,7 @@ object Types { * is also a singleton type. */ def instantiate(fromBelow: Boolean)(using Context): Type = - val tp = TypeComparer.instanceType(origin, fromBelow) + val tp = TypeComparer.instanceType(origin, fromBelow, widenUnions) if myInst.exists then // The line above might have triggered instantiation of the current type variable myInst else diff --git a/compiler/src/dotty/tools/dotc/typer/Namer.scala b/compiler/src/dotty/tools/dotc/typer/Namer.scala index 3a821f1dc65d..6b0712714c23 100644 --- a/compiler/src/dotty/tools/dotc/typer/Namer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Namer.scala @@ -1884,7 +1884,7 @@ class Namer { typer: Typer => TypeOps.simplify(tp.widenTermRefExpr, if defaultTp.exists then TypeOps.SimplifyKeepUnchecked() else null) match case ctp: ConstantType if sym.isInlineVal => ctp - case tp => TypeComparer.widenInferred(tp, pt) + case tp => TypeComparer.widenInferred(tp, pt, widenUnions = true) // Replace aliases to Unit by Unit itself. If we leave the alias in // it would be erased to BoxedUnit. diff --git a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala index 8b0fa88ad5a9..cc68ffc322da 100644 --- a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala @@ -489,7 +489,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context): val tparams = poly.paramRefs val variances = childClass.typeParams.map(_.paramVarianceSign) val instanceTypes = tparams.lazyZip(variances).map((tparam, variance) => - TypeComparer.instanceType(tparam, fromBelow = variance < 0)) + TypeComparer.instanceType(tparam, fromBelow = variance < 0, widenUnions = true)) resType.substParams(poly, instanceTypes) instantiate(using ctx.fresh.setExploreTyperState().setOwner(childClass)) case _ => diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index fa8cbaf7ff5a..07bf5d29c9e1 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -2808,7 +2808,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer if (ctx.mode.is(Mode.Pattern)) app1 else { val elemTpes = elems.lazyZip(pts).map((elem, pt) => - TypeComparer.widenInferred(elem.tpe, pt)) + TypeComparer.widenInferred(elem.tpe, pt, widenUnions = true)) val resTpe = TypeOps.nestedPairs(elemTpes) app1.cast(resTpe) } diff --git a/tests/pos/i14770.scala b/tests/pos/i14770.scala new file mode 100644 index 000000000000..182ccba21fdf --- /dev/null +++ b/tests/pos/i14770.scala @@ -0,0 +1,25 @@ +type UndefOr[A] = A | Unit + +extension [A](maybe: UndefOr[A]) + def foreach(f: A => Unit): Unit = + maybe match + case () => () + case a: A => f(a) + +trait Foo +trait Bar + +object Baz: + var booBap: Foo | Bar = _ + +def z: UndefOr[Foo | Bar] = ??? + +@main +def main = + z.foreach(x => Baz.booBap = x) + +def test[A](v: A | Unit): A | Unit = v +val x1 = test(5: Int | Unit) +val x2 = test(5: String | Int | Unit) +val _: Int | Unit = x1 +val _: String | Int | Unit = x2