Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Instantiate more type variables to hard unions
Browse files Browse the repository at this point in the history
odersky committed Jul 9, 2022

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 1724d84 commit cdce995
Showing 8 changed files with 92 additions and 32 deletions.
31 changes: 25 additions & 6 deletions compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala
Original file line number Diff line number Diff line change
@@ -12,6 +12,7 @@ import config.Printers.typr
import typer.ProtoTypes.{newTypeVar, representedParamRef}
import UnificationDirection.*
import NameKinds.AvoidNameKind
import NullOpsDecorator.stripNull

/** Methods for adding constraints and solving them.
*
@@ -525,10 +526,12 @@ trait ConstraintHandling {
* At this point we also drop the @Repeated annotation to avoid inferring type arguments with it,
* as those could leak the annotation to users (see run/inferred-repeated-result).
*/
def widenInferred(inst: Type, bound: Type)(using Context): Type =
def widenInferred(inst: Type, bound: Type, widenUnions: Boolean)(using Context): Type =
def widenOr(tp: Type) =
val tpw = tp.widenUnion
if (tpw ne tp) && (tpw <:< bound) then tpw else tp
if widenUnions then
val tpw = tp.widenUnion
if (tpw ne tp) && (tpw <:< bound) then tpw else tp
else tp.hardenUnions

def widenSingle(tp: Type) =
val tpw = tp.widenSingletons
@@ -548,24 +551,40 @@ trait ConstraintHandling {
wideInst.dropRepeatedAnnot
end widenInferred

extension (tp: Type) private def hardenUnions(using Context): Type = tp.widen match
case tp: AndType =>
tp.derivedAndType(tp.tp1.hardenUnions, tp.tp2.hardenUnions)
case tp: RefinedType =>
tp.derivedRefinedType(tp.parent.hardenUnions, tp.refinedName, tp.refinedInfo)
case tp: RecType =>
tp.rebind(tp.parent.hardenUnions)
case tp: HKTypeLambda =>
tp.derivedLambdaType(resType = tp.resType.hardenUnions)
case tp: OrType =>
val tp1 = tp.stripNull
if tp1 ne tp then tp.derivedOrType(tp1.hardenUnions, defn.NullType)
else tp.derivedOrType(tp.tp1.hardenUnions, tp.tp2.hardenUnions, soft = false)
case _ =>
tp

/** The instance type of `param` in the current constraint (which contains `param`).
* If `fromBelow` is true, the instance type is the lub of the parameter's
* lower bounds; otherwise it is the glb of its upper bounds. However,
* a lower bound instantiation can be a singleton type only if the upper bound
* is also a singleton type.
*/
def instanceType(param: TypeParamRef, fromBelow: Boolean)(using Context): Type = {
def instanceType(param: TypeParamRef, fromBelow: Boolean, widenUnions: Boolean)(using Context): Type = {
val approx = approximation(param, fromBelow).simplified
if fromBelow then
val widened = widenInferred(approx, param)
val widened = widenInferred(approx, param, widenUnions)
// Widening can add extra constraints, in particular the widened type might
// be a type variable which is now instantiated to `param`, and therefore
// cannot be used as an instantiation of `param` without creating a loop.
// If that happens, we run `instanceType` again to find a new instantation.
// (we do not check for non-toplevel occurences: those should never occur
// since `addOneBound` disallows recursive lower bounds).
if constraint.occursAtToplevel(param, widened) then
instanceType(param, fromBelow)
instanceType(param, fromBelow, widenUnions)
else
widened
else
54 changes: 33 additions & 21 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
@@ -492,23 +492,35 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
case AndType(tp21, tp22) => constrainRHSVars(tp21) && constrainRHSVars(tp22)
case _ => true

widenOK
|| joinOK
|| (tp1.isSoft || constrainRHSVars(tp2)) && recur(tp11, tp2) && recur(tp12, tp2)
|| containsAnd(tp1)
&& !joined
&& {
joined = true
try inFrozenGadt(recur(tp1.join, tp2))
finally joined = false
}
// An & on the left side loses information. We compensate by also trying the join.
// This is less ad-hoc than it looks since we produce joins in type inference,
// and then need to check that they are indeed supertypes of the original types
// under -Ycheck. Test case is i7965.scala.
// On the other hand, we could get a combinatorial explosion by applying such joins
// recursively, so we do it only once. See i14870.scala as a test case, which would
// loop for a very long time without the recursion brake.
def hardenTypeVars(tp2: Type): Unit = tp2.dealiasKeepRefiningAnnots match
case tvar: TypeVar if constraint.contains(tvar.origin) =>
tvar.widenUnions = false
case tp2: TypeParamRef if constraint.contains(tp2) =>
hardenTypeVars(constraint.typeVarOfParam(tp2))
case tp2: AndOrType =>
hardenTypeVars(tp2.tp1)
hardenTypeVars(tp2.tp2)
case _ =>

val res = widenOK
|| joinOK
|| (tp1.isSoft || constrainRHSVars(tp2)) && recur(tp11, tp2) && recur(tp12, tp2)
|| containsAnd(tp1)
&& !joined
&& {
joined = true
try inFrozenGadt(recur(tp1.join, tp2))
finally joined = false
}
// An & on the left side loses information. We compensate by also trying the join.
// This is less ad-hoc than it looks since we produce joins in type inference,
// and then need to check that they are indeed supertypes of the original types
// under -Ycheck. Test case is i7965.scala.
// On the other hand, we could get a combinatorial explosion by applying such joins
// recursively, so we do it only once. See i14870.scala as a test case, which would
// loop for a very long time without the recursion brake.
if res && !tp1.isSoft then hardenTypeVars(tp2)
res

case tp1: MatchType =>
val reduced = tp1.reduced
@@ -2851,8 +2863,8 @@ object TypeComparer {
def subtypeCheckInProgress(using Context): Boolean =
comparing(_.subtypeCheckInProgress)

def instanceType(param: TypeParamRef, fromBelow: Boolean)(using Context): Type =
comparing(_.instanceType(param, fromBelow))
def instanceType(param: TypeParamRef, fromBelow: Boolean, widenUnions: Boolean)(using Context): Type =
comparing(_.instanceType(param, fromBelow, widenUnions))

def approximation(param: TypeParamRef, fromBelow: Boolean)(using Context): Type =
comparing(_.approximation(param, fromBelow))
@@ -2872,8 +2884,8 @@ object TypeComparer {
def addToConstraint(tl: TypeLambda, tvars: List[TypeVar])(using Context): Boolean =
comparing(_.addToConstraint(tl, tvars))

def widenInferred(inst: Type, bound: Type)(using Context): Type =
comparing(_.widenInferred(inst, bound))
def widenInferred(inst: Type, bound: Type, widenUnions: Boolean)(using Context): Type =
comparing(_.widenInferred(inst, bound, widenUnions))

def dropTransparentTraits(tp: Type, bound: Type)(using Context): Type =
comparing(_.dropTransparentTraits(tp, bound))
4 changes: 3 additions & 1 deletion compiler/src/dotty/tools/dotc/core/TypeOps.scala
Original file line number Diff line number Diff line change
@@ -517,7 +517,9 @@ object TypeOps:
override def apply(tp: Type): Type = tp match
case tp: TypeVar if mapCtx.typerState.constraint.contains(tp) =>
val lo = TypeComparer.instanceType(
tp.origin, fromBelow = variance > 0 || variance == 0 && tp.hasLowerBound)(using mapCtx)
tp.origin,
fromBelow = variance > 0 || variance == 0 && tp.hasLowerBound,
widenUnions = tp.widenUnions)(using mapCtx)
val lo1 = apply(lo)
if (lo1 ne lo) lo1 else tp
case _ =>
4 changes: 3 additions & 1 deletion compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
@@ -4507,6 +4507,8 @@ object Types {
final class TypeVar private(initOrigin: TypeParamRef, creatorState: TyperState | Null, val nestingLevel: Int) extends CachedProxyType with ValueType {
private var currentOrigin = initOrigin

var widenUnions = true

def origin: TypeParamRef = currentOrigin

/** Set origin to new parameter. Called if we merge two conflicting constraints.
@@ -4569,7 +4571,7 @@ object Types {
* is also a singleton type.
*/
def instantiate(fromBelow: Boolean)(using Context): Type =
val tp = TypeComparer.instanceType(origin, fromBelow)
val tp = TypeComparer.instanceType(origin, fromBelow, widenUnions)
if myInst.exists then // The line above might have triggered instantiation of the current type variable
myInst
else
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Namer.scala
Original file line number Diff line number Diff line change
@@ -1884,7 +1884,7 @@ class Namer { typer: Typer =>
TypeOps.simplify(tp.widenTermRefExpr,
if defaultTp.exists then TypeOps.SimplifyKeepUnchecked() else null) match
case ctp: ConstantType if sym.isInlineVal => ctp
case tp => TypeComparer.widenInferred(tp, pt)
case tp => TypeComparer.widenInferred(tp, pt, widenUnions = true)

// Replace aliases to Unit by Unit itself. If we leave the alias in
// it would be erased to BoxedUnit.
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Synthesizer.scala
Original file line number Diff line number Diff line change
@@ -489,7 +489,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
val tparams = poly.paramRefs
val variances = childClass.typeParams.map(_.paramVarianceSign)
val instanceTypes = tparams.lazyZip(variances).map((tparam, variance) =>
TypeComparer.instanceType(tparam, fromBelow = variance < 0))
TypeComparer.instanceType(tparam, fromBelow = variance < 0, widenUnions = true))
resType.substParams(poly, instanceTypes)
instantiate(using ctx.fresh.setExploreTyperState().setOwner(childClass))
case _ =>
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
@@ -2808,7 +2808,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
if (ctx.mode.is(Mode.Pattern)) app1
else {
val elemTpes = elems.lazyZip(pts).map((elem, pt) =>
TypeComparer.widenInferred(elem.tpe, pt))
TypeComparer.widenInferred(elem.tpe, pt, widenUnions = true))
val resTpe = TypeOps.nestedPairs(elemTpes)
app1.cast(resTpe)
}
25 changes: 25 additions & 0 deletions tests/pos/i14770.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
type UndefOr[A] = A | Unit

extension [A](maybe: UndefOr[A])
def foreach(f: A => Unit): Unit =
maybe match
case () => ()
case a: A => f(a)

trait Foo
trait Bar

object Baz:
var booBap: Foo | Bar = _

def z: UndefOr[Foo | Bar] = ???

@main
def main =
z.foreach(x => Baz.booBap = x)

def test[A](v: A | Unit): A | Unit = v
val x1 = test(5: Int | Unit)
val x2 = test(5: String | Int | Unit)
val _: Int | Unit = x1
val _: String | Int | Unit = x2

0 comments on commit cdce995

Please sign in to comment.