Skip to content

Commit

Permalink
Avoid instantiating variables to types with deeper nesting levels
Browse files Browse the repository at this point in the history
  • Loading branch information
odersky committed Aug 18, 2022
1 parent d670018 commit 824496b
Show file tree
Hide file tree
Showing 12 changed files with 125 additions and 34 deletions.
107 changes: 93 additions & 14 deletions compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import config.Printers.typr
import typer.ProtoTypes.{newTypeVar, representedParamRef}
import UnificationDirection.*
import NameKinds.AvoidNameKind
import util.SimpleIdentitySet

/** Methods for adding constraints and solving them.
*
Expand Down Expand Up @@ -74,8 +75,41 @@ trait ConstraintHandling {
protected def necessaryConstraintsOnly(using Context): Boolean =
ctx.mode.is(Mode.GadtConstraintInference) || myNecessaryConstraintsOnly

/** If `trustBounds = false` we perform comparisons in a pessimistic way as follows:
* Given an abstract type `A >: L <: H`, a subtype comparison of any type
* with `A` will compare against both `L` and `H`. E.g.
*
* T <:< A if T <:< L and T <:< H
* A <:< T if L <:< T and H <:< T
*
* This restricted form makes sure we don't "forget" types when forming
* unions and intersections with abstract types that have bad bounds. E.g.
* the following example from neg/i8900.scala that @smarter came up with:
* We have a type variable X with constraints
*
* X >: 1, X >: x.M
*
* where `x` is a locally nested variable and `x.M` has bad bounds
*
* x.M >: Int | String <: Int & String
*
* If we trust bounds, then the lower bound of `X` is `x.M` since `x.M >: 1`.
* Then even if we correct levels on instantiation to eliminate the local `x`,
* it is alreay too late, we'd get `Int & String` as instance, which does not
* satisfy the original constraint `X >: 1`.
*
* But if `trustBounds` is false, we do not conclude the `x.M >: 1` since
* we compare both bounds and the upper bound `Int & String` is not a supertype
* of `1`. So the lower bound is `1 | x.M` and when we level-avoid that we
* get `1 | Int & String`, which simplifies to `Int`.
*/
protected var trustBounds = true

inline def withUntrustedBounds(op: => Type): Type =
val saved = trustBounds
trustBounds = false
try op finally trustBounds = saved

def checkReset() =
assert(addConstraintInvocations == 0)
assert(frozenConstraint == false)
Expand Down Expand Up @@ -262,16 +296,14 @@ trait ConstraintHandling {
// If `isUpper` is true, ensure that `param <: `bound`, otherwise ensure
// that `param >: bound`.
val narrowedBounds =
val savedHomogenizeArgs = homogenizeArgs
val savedTrustBounds = trustBounds
val saved = homogenizeArgs
homogenizeArgs = Config.alignArgsInAnd
try
trustBounds = false
if isUpper then oldBounds.derivedTypeBounds(lo, hi & bound)
else oldBounds.derivedTypeBounds(lo | bound, hi)
withUntrustedBounds(
if isUpper then oldBounds.derivedTypeBounds(lo, hi & bound)
else oldBounds.derivedTypeBounds(lo | bound, hi))
finally
homogenizeArgs = savedHomogenizeArgs
trustBounds = savedTrustBounds
homogenizeArgs = saved
//println(i"narrow bounds for $param from $oldBounds to $narrowedBounds")
val c1 = constraint.updateEntry(param, narrowedBounds)
(c1 eq constraint)
Expand Down Expand Up @@ -431,6 +463,49 @@ trait ConstraintHandling {
}
}

private def fixLevels(tp: Type, fromBelow: Boolean, maxLevel: Int, param: TypeParamRef)(using Context) =

def needsFix(tp: NamedType) =
(tp.prefix eq NoPrefix) && tp.symbol.nestingLevel > maxLevel

class NeedsLeveling extends TypeAccumulator[Boolean]:
if !fromBelow then variance = -1
var nestedVarsLo, nestedVarsHi: SimpleIdentitySet[TypeVar] = SimpleIdentitySet.empty
def apply(need: Boolean, tp: Type) =
need || tp.match
case tp: NamedType =>
needsFix(tp)
|| !stopBecauseStaticOrLocal(tp) && apply(need, tp.prefix)
case tp: TypeVar =>
val inst = tp.instanceOpt
if inst.exists then apply(need, inst)
else if tp.nestingLevel > maxLevel then
if variance > 0 then nestedVarsLo += tp
else if variance < 0 then nestedVarsHi += tp
else tp.nestingLevel = maxLevel
true
else false
case _ =>
foldOver(need, tp)

class LevelAvoidMap extends TypeOps.AvoidMap:
if !fromBelow then variance = -1
def toAvoid(tp: NamedType) = needsFix(tp)
//override def apply(tp: Type): Type = tp match
// case tp: LazyRef => tp
// case _ => super.apply(tp)

if ctx.isAfterTyper then tp
else
val needsLeveling = NeedsLeveling()
if needsLeveling(false, tp) then
typr.println(i"instance $tp for $param needs leveling to $maxLevel, nested = ${needsLeveling.nestedVarsLo.toList} | ${needsLeveling.nestedVarsHi.toList}")
needsLeveling.nestedVarsLo.foreach(_.instantiate(fromBelow = true))
needsLeveling.nestedVarsHi.foreach(_.instantiate(fromBelow = false))
LevelAvoidMap()(tp)
else tp
end fixLevels

/** Solve constraint set for given type parameter `param`.
* If `fromBelow` is true the parameter is approximated by its lower bound,
* otherwise it is approximated by its upper bound, unless the upper bound
Expand All @@ -442,13 +517,17 @@ trait ConstraintHandling {
* @return the instantiating type
* @pre `param` is in the constraint's domain.
*/
final def approximation(param: TypeParamRef, fromBelow: Boolean)(using Context): Type =
final def approximation(param: TypeParamRef, fromBelow: Boolean, maxLevel: Int)(using Context): Type =
constraint.entry(param) match
case entry: TypeBounds =>
val useLowerBound = fromBelow || param.occursIn(entry.hi)
val inst = if useLowerBound then fullLowerBound(param) else fullUpperBound(param)
typr.println(s"approx ${param.show}, from below = $fromBelow, inst = ${inst.show}")
inst
val rawInst = withUntrustedBounds(
if useLowerBound then fullLowerBound(param) else fullUpperBound(param))
val levelInst = fixLevels(rawInst, fromBelow, maxLevel, param)
if levelInst ne rawInst then
typr.println(i"level avoid for $maxLevel: $rawInst --> $levelInst")
typr.println(i"approx $param, from below = $fromBelow, inst = $levelInst")
levelInst
case inst =>
assert(inst.exists, i"param = $param\nconstraint = $constraint")
inst
Expand Down Expand Up @@ -561,8 +640,8 @@ trait ConstraintHandling {
* a lower bound instantiation can be a singleton type only if the upper bound
* is also a singleton type.
*/
def instanceType(param: TypeParamRef, fromBelow: Boolean)(using Context): Type = {
val approx = approximation(param, fromBelow).simplified
def instanceType(param: TypeParamRef, fromBelow: Boolean, maxLevel: Int)(using Context): Type = {
val approx = approximation(param, fromBelow, maxLevel).simplified
if fromBelow then
val widened = widenInferred(approx, param)
// Widening can add extra constraints, in particular the widened type might
Expand All @@ -572,7 +651,7 @@ trait ConstraintHandling {
// (we do not check for non-toplevel occurences: those should never occur
// since `addOneBound` disallows recursive lower bounds).
if constraint.occursAtToplevel(param, widened) then
instanceType(param, fromBelow)
instanceType(param, fromBelow, maxLevel)
else
widened
else
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/Contexts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ object Contexts {
protected def scope_=(scope: Scope): Unit = _scope = scope
final def scope: Scope = _scope

/** The current type comparer */
/** The current typerstate */
private var _typerState: TyperState = _
protected def typerState_=(typerState: TyperState): Unit = _typerState = typerState
final def typerState: TyperState = _typerState
Expand Down
8 changes: 4 additions & 4 deletions compiler/src/dotty/tools/dotc/core/GadtConstraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ sealed abstract class GadtConstraint extends Showable {
def isNarrowing: Boolean

/** See [[ConstraintHandling.approximation]] */
def approximation(sym: Symbol, fromBelow: Boolean)(using Context): Type
def approximation(sym: Symbol, fromBelow: Boolean, maxLevel: Int = Int.MaxValue)(using Context): Type

def symbols: List[Symbol]

Expand Down Expand Up @@ -205,9 +205,9 @@ final class ProperGadtConstraint private(

def isNarrowing: Boolean = wasConstrained

override def approximation(sym: Symbol, fromBelow: Boolean)(using Context): Type = {
override def approximation(sym: Symbol, fromBelow: Boolean, maxLevel: Int)(using Context): Type = {
val res =
approximation(tvarOrError(sym).origin, fromBelow = fromBelow) match
approximation(tvarOrError(sym).origin, fromBelow, maxLevel) match
case tpr: TypeParamRef =>
// Here we do externalization when the returned type is a TypeParamRef,
// b/c ConstraintHandling.approximation may return internal types when
Expand Down Expand Up @@ -317,7 +317,7 @@ final class ProperGadtConstraint private(
override def addToConstraint(params: List[Symbol])(using Context): Boolean = unsupported("EmptyGadtConstraint.addToConstraint")
override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = unsupported("EmptyGadtConstraint.addBound")

override def approximation(sym: Symbol, fromBelow: Boolean)(using Context): Type = unsupported("EmptyGadtConstraint.approximation")
override def approximation(sym: Symbol, fromBelow: Boolean, maxLevel: Int)(using Context): Type = unsupported("EmptyGadtConstraint.approximation")

override def symbols: List[Symbol] = Nil

Expand Down
10 changes: 5 additions & 5 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2863,11 +2863,11 @@ object TypeComparer {
def subtypeCheckInProgress(using Context): Boolean =
comparing(_.subtypeCheckInProgress)

def instanceType(param: TypeParamRef, fromBelow: Boolean)(using Context): Type =
comparing(_.instanceType(param, fromBelow))
def instanceType(param: TypeParamRef, fromBelow: Boolean, maxLevel: Int = Int.MaxValue)(using Context): Type =
comparing(_.instanceType(param, fromBelow, maxLevel))

def approximation(param: TypeParamRef, fromBelow: Boolean)(using Context): Type =
comparing(_.approximation(param, fromBelow))
def approximation(param: TypeParamRef, fromBelow: Boolean, maxLevel: Int = Int.MaxValue)(using Context): Type =
comparing(_.approximation(param, fromBelow, maxLevel))

def bounds(param: TypeParamRef)(using Context): TypeBounds =
comparing(_.bounds(param))
Expand Down Expand Up @@ -2953,7 +2953,7 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
case param @ TypeParamRef(b, n) if b eq caseLambda =>
insts(n) =
if canApprox then
approximation(param, fromBelow = variance >= 0).simplified
approximation(param, fromBelow = variance >= 0, Int.MaxValue).simplified
else constraint.entry(param) match
case entry: TypeBounds =>
val lo = fullLowerBound(param)
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4509,7 +4509,7 @@ object Types {
* - On instantiation, replacing any param in the param bound
* with a level greater than nestingLevel (see `fullLowerBound`).
*/
final class TypeVar private(initOrigin: TypeParamRef, creatorState: TyperState | Null, val nestingLevel: Int) extends CachedProxyType with ValueType {
final class TypeVar private(initOrigin: TypeParamRef, creatorState: TyperState | Null, var nestingLevel: Int) extends CachedProxyType with ValueType {
private var currentOrigin = initOrigin

def origin: TypeParamRef = currentOrigin
Expand Down Expand Up @@ -4574,7 +4574,7 @@ object Types {
* is also a singleton type.
*/
def instantiate(fromBelow: Boolean)(using Context): Type =
val tp = TypeComparer.instanceType(origin, fromBelow)
val tp = TypeComparer.instanceType(origin, fromBelow, nestingLevel)
if myInst.exists then // The line above might have triggered instantiation of the current type variable
myInst
else
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/inlines/InlineReducer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ class InlineReducer(inliner: Inliner)(using Context):

def addTypeBindings(typeBinds: TypeBindsMap)(using Context): Unit =
typeBinds.foreachBinding { case (sym, shouldBeMinimized) =>
newTypeBinding(sym, ctx.gadt.approximation(sym, fromBelow = shouldBeMinimized))
newTypeBinding(sym, ctx.gadt.approximation(sym, fromBelow = shouldBeMinimized, Int.MaxValue))
}

def registerAsGadtSyms(typeBinds: TypeBindsMap)(using Context): Unit =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ object TypeTestsCasts {
private[transform] def foundClasses(tp: Type)(using Context): List[Symbol] =
def go(tp: Type, acc: List[Type])(using Context): List[Type] = tp.dealias match
case OrType(tp1, tp2) => go(tp2, go(tp1, acc))
case AndType(tp1, tp2) => (for t1 <- go(tp1, Nil); t2 <- go(tp2, Nil); yield AndType(t1, t2)) ::: acc
case AndType(tp1, tp2) => (for t1 <- go(tp1, Nil); t2 <- go(tp2, Nil) yield AndType(t1, t2)) ::: acc
case _ => tp :: acc
go(tp, Nil).map(effectiveClass)
}
3 changes: 0 additions & 3 deletions compiler/test/dotc/pos-test-pickling.blacklist
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,3 @@ i4176-gadt.scala
i13974a.scala

java-inherited-type1

# avoidance bug
i15174.scala
2 changes: 1 addition & 1 deletion tests/pending/run/i8861.scala → tests/neg/i8861.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ object Test {
// now infers `c.visit[(Int & M | String & M)]`
def minimalFail[M](c: Container { type A = M }): M = c.visit(
int = vi => vi.i : vi.A,
str = vs => vs.t : vs.A
str = vs => vs.t : vs.A // error
)

def main(args: Array[String]): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ object Test {
def inv(cond: Boolean) = // used to leak: Inv[x.type]
if (cond)
val x: Int = 1
new Inv(x)
new Inv(x) // error
else
Inv.empty
Inv.empty // error

}
File renamed without changes.
14 changes: 14 additions & 0 deletions tests/pos/i15595.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
trait MatchResult[+T]

@main def Test() = {
def convert[T <: Seq[_], U <: MatchResult[_]](fn: T => U)(implicit x: Seq[_] = Seq.empty): U = ???
def resultOf[T](v: T): MatchResult[T] = ???

convert { _ =>
type R = String
resultOf[R](???)
// this would not lead to crash:
// val x = resultOf[R](???)
// x
}
}

0 comments on commit 824496b

Please sign in to comment.