Skip to content

Commit

Permalink
Instantiate more type variables to hard unions
Browse files Browse the repository at this point in the history
  • Loading branch information
odersky committed Aug 29, 2022
1 parent 63344e7 commit c1f35aa
Show file tree
Hide file tree
Showing 9 changed files with 128 additions and 38 deletions.
39 changes: 31 additions & 8 deletions compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import typer.ProtoTypes.{newTypeVar, representedParamRef}
import UnificationDirection.*
import NameKinds.AvoidNameKind
import util.SimpleIdentitySet
import NullOpsDecorator.stripNull

/** Methods for adding constraints and solving them.
*
Expand Down Expand Up @@ -627,8 +628,11 @@ trait ConstraintHandling {
* 1. If `inst` is a singleton type, or a union containing some singleton types,
* widen (all) the singleton type(s), provided the result is a subtype of `bound`.
* (i.e. `inst.widenSingletons <:< bound` succeeds with satisfiable constraint)
* 2. If `inst` is a union type, approximate the union type from above by an intersection
* of all common base types, provided the result is a subtype of `bound`.
* 2a. If `inst` is a union type and `widenUnions` is true, approximate the union type
* from above by an intersection of all common base types, provided the result
* is a subtype of `bound`.
* 2b. If `inst` is a union type and `widenUnions` is false, turn it into a hard
* union type (except for unions | Null, which are kept in the state they were).
* 3. Widen some irreducible applications of higher-kinded types to wildcard arguments
* (see @widenIrreducible).
* 4. Drop transparent traits from intersections (see @dropTransparentTraits).
Expand All @@ -641,10 +645,12 @@ trait ConstraintHandling {
* At this point we also drop the @Repeated annotation to avoid inferring type arguments with it,
* as those could leak the annotation to users (see run/inferred-repeated-result).
*/
def widenInferred(inst: Type, bound: Type)(using Context): Type =
def widenInferred(inst: Type, bound: Type, widenUnions: Boolean)(using Context): Type =
def widenOr(tp: Type) =
val tpw = tp.widenUnion
if (tpw ne tp) && (tpw <:< bound) then tpw else tp
if widenUnions then
val tpw = tp.widenUnion
if (tpw ne tp) && (tpw <:< bound) then tpw else tp
else tp.hardenUnions

def widenSingle(tp: Type) =
val tpw = tp.widenSingletons
Expand All @@ -664,6 +670,23 @@ trait ConstraintHandling {
wideInst.dropRepeatedAnnot
end widenInferred

/** Convert all toplevel union types in `tp` to hard unions */
extension (tp: Type) private def hardenUnions(using Context): Type = tp.widen match
case tp: AndType =>
tp.derivedAndType(tp.tp1.hardenUnions, tp.tp2.hardenUnions)
case tp: RefinedType =>
tp.derivedRefinedType(tp.parent.hardenUnions, tp.refinedName, tp.refinedInfo)
case tp: RecType =>
tp.rebind(tp.parent.hardenUnions)
case tp: HKTypeLambda =>
tp.derivedLambdaType(resType = tp.resType.hardenUnions)
case tp: OrType =>
val tp1 = tp.stripNull
if tp1 ne tp then tp.derivedOrType(tp1.hardenUnions, defn.NullType)
else tp.derivedOrType(tp.tp1.hardenUnions, tp.tp2.hardenUnions, soft = false)
case _ =>
tp

/** The instance type of `param` in the current constraint (which contains `param`).
* If `fromBelow` is true, the instance type is the lub of the parameter's
* lower bounds; otherwise it is the glb of its upper bounds. However,
Expand All @@ -672,18 +695,18 @@ trait ConstraintHandling {
* The instance type is not allowed to contain references to types nested deeper
* than `maxLevel`.
*/
def instanceType(param: TypeParamRef, fromBelow: Boolean, maxLevel: Int)(using Context): Type = {
def instanceType(param: TypeParamRef, fromBelow: Boolean, widenUnions: Boolean, maxLevel: Int)(using Context): Type = {
val approx = approximation(param, fromBelow, maxLevel).simplified
if fromBelow then
val widened = widenInferred(approx, param)
val widened = widenInferred(approx, param, widenUnions)
// Widening can add extra constraints, in particular the widened type might
// be a type variable which is now instantiated to `param`, and therefore
// cannot be used as an instantiation of `param` without creating a loop.
// If that happens, we run `instanceType` again to find a new instantation.
// (we do not check for non-toplevel occurences: those should never occur
// since `addOneBound` disallows recursive lower bounds).
if constraint.occursAtToplevel(param, widened) then
instanceType(param, fromBelow, maxLevel)
instanceType(param, fromBelow, widenUnions, maxLevel)
else
widened
else
Expand Down
67 changes: 45 additions & 22 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -487,31 +487,54 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling

// If LHS is a hard union, constrain any type variables of the RHS with it as lower bound
// before splitting the LHS into its constituents. That way, the RHS variables are
// constraint by the hard union and can be instantiated to it. If we just split and add
// constrained by the hard union and can be instantiated to it. If we just split and add
// the two parts of the LHS separately to the constraint, the lower bound would become
// a soft union.
def constrainRHSVars(tp2: Type): Boolean = tp2.dealiasKeepRefiningAnnots match
case tp2: TypeParamRef if constraint contains tp2 => compareTypeParamRef(tp2)
case AndType(tp21, tp22) => constrainRHSVars(tp21) && constrainRHSVars(tp22)
case _ => true

widenOK
|| joinOK
|| (tp1.isSoft || constrainRHSVars(tp2)) && recur(tp11, tp2) && recur(tp12, tp2)
|| containsAnd(tp1)
&& !joined
&& {
joined = true
try inFrozenGadt(recur(tp1.join, tp2))
finally joined = false
}
// An & on the left side loses information. We compensate by also trying the join.
// This is less ad-hoc than it looks since we produce joins in type inference,
// and then need to check that they are indeed supertypes of the original types
// under -Ycheck. Test case is i7965.scala.
// On the other hand, we could get a combinatorial explosion by applying such joins
// recursively, so we do it only once. See i14870.scala as a test case, which would
// loop for a very long time without the recursion brake.
/** Mark toplevel type vars in `tp2` as hard in the current typerState */
def hardenTypeVars(tp2: Type): Unit = tp2.dealiasKeepRefiningAnnots match
case tvar: TypeVar if constraint.contains(tvar.origin) =>
state.hardVars += tvar
case tp2: TypeParamRef if constraint.contains(tp2) =>
hardenTypeVars(constraint.typeVarOfParam(tp2))
case tp2: AndOrType =>
hardenTypeVars(tp2.tp1)
hardenTypeVars(tp2.tp2)
case _ =>

val res = widenOK
|| joinOK
|| (tp1.isSoft || constrainRHSVars(tp2)) && recur(tp11, tp2) && recur(tp12, tp2)
|| containsAnd(tp1)
&& !joined
&& {
joined = true
try inFrozenGadt(recur(tp1.join, tp2))
finally joined = false
}
// An & on the left side loses information. We compensate by also trying the join.
// This is less ad-hoc than it looks since we produce joins in type inference,
// and then need to check that they are indeed supertypes of the original types
// under -Ycheck. Test case is i7965.scala.
// On the other hand, we could get a combinatorial explosion by applying such joins
// recursively, so we do it only once. See i14870.scala as a test case, which would
// loop for a very long time without the recursion brake.

if res && !tp1.isSoft then
// We use a heuristic here where every toplevel type variable on the right hand side
// is marked so that it converts all soft unions in its lower bound to hard unions
// before it is instantiated. The reason is that the union might have come from
// (decomposed and reconstituted) `tp1`. But of course there might be false positives
// where we also treat unions that come from elsewhere as hard unions. Or the constraint
// that created the union is ultimately thrown away, but the type variable will
// stay marked. So it is a coarse measure to take. But it works in the obvious cases.
hardenTypeVars(tp2)

res

case CapturingType(parent1, refs1) =>
if subCaptures(refs1, tp2.captureSet, frozenConstraint).isOK && sameBoxed(tp1, tp2, refs1)
Expand Down Expand Up @@ -2960,8 +2983,8 @@ object TypeComparer {
def subtypeCheckInProgress(using Context): Boolean =
comparing(_.subtypeCheckInProgress)

def instanceType(param: TypeParamRef, fromBelow: Boolean, maxLevel: Int = Int.MaxValue)(using Context): Type =
comparing(_.instanceType(param, fromBelow, maxLevel))
def instanceType(param: TypeParamRef, fromBelow: Boolean, widenUnions: Boolean, maxLevel: Int = Int.MaxValue)(using Context): Type =
comparing(_.instanceType(param, fromBelow, widenUnions, maxLevel))

def approximation(param: TypeParamRef, fromBelow: Boolean, maxLevel: Int = Int.MaxValue)(using Context): Type =
comparing(_.approximation(param, fromBelow, maxLevel))
Expand All @@ -2981,8 +3004,8 @@ object TypeComparer {
def addToConstraint(tl: TypeLambda, tvars: List[TypeVar])(using Context): Boolean =
comparing(_.addToConstraint(tl, tvars))

def widenInferred(inst: Type, bound: Type)(using Context): Type =
comparing(_.widenInferred(inst, bound))
def widenInferred(inst: Type, bound: Type, widenUnions: Boolean)(using Context): Type =
comparing(_.widenInferred(inst, bound, widenUnions))

def dropTransparentTraits(tp: Type, bound: Type)(using Context): Type =
comparing(_.dropTransparentTraits(tp, bound))
Expand Down
4 changes: 3 additions & 1 deletion compiler/src/dotty/tools/dotc/core/TypeOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,9 @@ object TypeOps:
override def apply(tp: Type): Type = tp match
case tp: TypeVar if mapCtx.typerState.constraint.contains(tp) =>
val lo = TypeComparer.instanceType(
tp.origin, fromBelow = variance > 0 || variance == 0 && tp.hasLowerBound)(using mapCtx)
tp.origin,
fromBelow = variance > 0 || variance == 0 && tp.hasLowerBound,
widenUnions = tp.widenUnions)(using mapCtx)
val lo1 = apply(lo)
if (lo1 ne lo) lo1 else tp
case _ =>
Expand Down
19 changes: 16 additions & 3 deletions compiler/src/dotty/tools/dotc/core/TyperState.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,20 @@ object TyperState {

type LevelMap = SimpleIdentityMap[TypeVar, Integer]

opaque type Snapshot = (Constraint, TypeVars, LevelMap)
opaque type Snapshot = (Constraint, TypeVars, TypeVars, LevelMap)

extension (ts: TyperState)
def snapshot()(using Context): Snapshot =
(ts.constraint, ts.ownedVars, ts.upLevels)
(ts.constraint, ts.ownedVars, ts.hardVars, ts.upLevels)

def resetTo(state: Snapshot)(using Context): Unit =
val (constraint, ownedVars, upLevels) = state
val (constraint, ownedVars, hardVars, upLevels) = state
for tv <- ownedVars do
if !ts.ownedVars.contains(tv) then // tv has been instantiated
tv.resetInst(ts)
ts.constraint = constraint
ts.ownedVars = ownedVars
ts.hardVars = hardVars
ts.upLevels = upLevels
}

Expand Down Expand Up @@ -91,6 +92,14 @@ class TyperState() {
def ownedVars: TypeVars = myOwnedVars
def ownedVars_=(vs: TypeVars): Unit = myOwnedVars = vs

/** The set of type variables `tv` such that, if `tv` is instantiated to
* its lower bound, top-level soft unions in the instance type are converted
* to hard unions instead of being widened in `widenOr`.
*/
private var myHardVars: TypeVars = _
def hardVars: TypeVars = myHardVars
def hardVars_=(tvs: TypeVars): Unit = myHardVars = tvs

private var upLevels: LevelMap = _

/** Initializes all fields except reporter, isCommittable, which need to be
Expand All @@ -103,6 +112,7 @@ class TyperState() {
this.myConstraint = constraint
this.previousConstraint = constraint
this.myOwnedVars = SimpleIdentitySet.empty
this.myHardVars = SimpleIdentitySet.empty
this.upLevels = SimpleIdentityMap.empty
this.isCommitted = false
this
Expand All @@ -114,6 +124,7 @@ class TyperState() {
val ts = TyperState().init(this, this.constraint)
.setReporter(reporter)
.setCommittable(committable)
ts.hardVars = this.hardVars
ts.upLevels = upLevels
ts

Expand Down Expand Up @@ -180,6 +191,7 @@ class TyperState() {
constr.println(i"committing $this to $targetState, fromConstr = $constraint, toConstr = ${targetState.constraint}")
if targetState.constraint eq previousConstraint then
targetState.constraint = constraint
targetState.hardVars = hardVars
if !ownedVars.isEmpty then ownedVars.foreach(targetState.includeVar)
else
targetState.mergeConstraintWith(this)
Expand Down Expand Up @@ -238,6 +250,7 @@ class TyperState() {
val otherLos = other.lower(p)
val otherHis = other.upper(p)
val otherEntry = other.entry(p)
if that.hardVars.contains(tv) then this.myHardVars += tv
( (otherLos eq constraint.lower(p)) || otherLos.forall(_ <:< p)) &&
( (otherHis eq constraint.upper(p)) || otherHis.forall(p <:< _)) &&
((otherEntry eq constraint.entry(p)) || otherEntry.match
Expand Down
6 changes: 5 additions & 1 deletion compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4714,12 +4714,16 @@ object Types {
* is also a singleton type.
*/
def instantiate(fromBelow: Boolean)(using Context): Type =
val tp = TypeComparer.instanceType(origin, fromBelow, nestingLevel)
val tp = TypeComparer.instanceType(origin, fromBelow, widenUnions, nestingLevel)
if myInst.exists then // The line above might have triggered instantiation of the current type variable
myInst
else
instantiateWith(tp)

/** Widen unions when instantiating this variable in the current context? */
def widenUnions(using Context): Boolean =
!ctx.typerState.hardVars.contains(this)

/** For uninstantiated type variables: the entry in the constraint (either bounds or
* provisional instance value)
*/
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Namer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1888,7 +1888,7 @@ class Namer { typer: Typer =>
TypeOps.simplify(tp.widenTermRefExpr,
if defaultTp.exists then TypeOps.SimplifyKeepUnchecked() else null) match
case ctp: ConstantType if sym.isInlineVal => ctp
case tp => TypeComparer.widenInferred(tp, pt)
case tp => TypeComparer.widenInferred(tp, pt, widenUnions = true)

// Replace aliases to Unit by Unit itself. If we leave the alias in
// it would be erased to BoxedUnit.
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Synthesizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
val tparams = poly.paramRefs
val variances = childClass.typeParams.map(_.paramVarianceSign)
val instanceTypes = tparams.lazyZip(variances).map((tparam, variance) =>
TypeComparer.instanceType(tparam, fromBelow = variance < 0)
TypeComparer.instanceType(tparam, fromBelow = variance < 0, widenUnions = true)
)
val instanceType = resType.substParams(poly, instanceTypes)
// this is broken in tests/run/i13332intersection.scala,
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2847,7 +2847,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
if (ctx.mode.is(Mode.Pattern)) app1
else {
val elemTpes = elems.lazyZip(pts).map((elem, pt) =>
TypeComparer.widenInferred(elem.tpe, pt))
TypeComparer.widenInferred(elem.tpe, pt, widenUnions = true))
val resTpe = TypeOps.nestedPairs(elemTpes)
app1.cast(resTpe)
}
Expand Down
25 changes: 25 additions & 0 deletions tests/pos/i14770.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
type UndefOr[A] = A | Unit

extension [A](maybe: UndefOr[A])
def foreach(f: A => Unit): Unit =
maybe match
case () => ()
case a: A => f(a)

trait Foo
trait Bar

object Baz:
var booBap: Foo | Bar = _

def z: UndefOr[Foo | Bar] = ???

@main
def main =
z.foreach(x => Baz.booBap = x)

def test[A](v: A | Unit): A | Unit = v
val x1 = test(5: Int | Unit)
val x2 = test(5: String | Int | Unit)
val _: Int | Unit = x1
val _: String | Int | Unit = x2

0 comments on commit c1f35aa

Please sign in to comment.