Skip to content

Commit

Permalink
Merge pull request scala#15746 from dotty-staging/fix-level-checking
Browse files Browse the repository at this point in the history
Do level checking on instantiation
  • Loading branch information
odersky authored Aug 29, 2022
2 parents 8a7c84c + f01abfb commit 3e20051
Show file tree
Hide file tree
Showing 19 changed files with 288 additions and 65 deletions.
11 changes: 7 additions & 4 deletions compiler/src/dotty/tools/dotc/config/Config.scala
Original file line number Diff line number Diff line change
Expand Up @@ -227,9 +227,12 @@ object Config {
*/
inline val reuseSymDenotations = true

/** If true, check levels of type variables and create fresh ones as needed.
* This is necessary for soundness (see 3ab18a9), but also causes several
* regressions that should be fixed before turning this on.
/** If `checkLevelsOnConstraints` is true, check levels of type variables
* and create fresh ones as needed when bounds are first entered intot he constraint.
* If `checkLevelsOnInstantiation` is true, allow level-incorrect constraints but
* fix levels on type variable instantiation.
*/
inline val checkLevels = false
inline val checkLevelsOnConstraints = false
inline val checkLevelsOnInstantiation = true

}
143 changes: 127 additions & 16 deletions compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import config.Printers.typr
import typer.ProtoTypes.{newTypeVar, representedParamRef}
import UnificationDirection.*
import NameKinds.AvoidNameKind
import util.SimpleIdentitySet

/** Methods for adding constraints and solving them.
*
Expand Down Expand Up @@ -74,7 +75,43 @@ trait ConstraintHandling {
protected def necessaryConstraintsOnly(using Context): Boolean =
ctx.mode.is(Mode.GadtConstraintInference) || myNecessaryConstraintsOnly

protected var trustBounds = true
/** If `trustBounds = false` we perform comparisons in a pessimistic way as follows:
* Given an abstract type `A >: L <: H`, a subtype comparison of any type
* with `A` will compare against both `L` and `H`. E.g.
*
* T <:< A if T <:< L and T <:< H
* A <:< T if L <:< T and H <:< T
*
* This restricted form makes sure we don't "forget" types when forming
* unions and intersections with abstract types that have bad bounds. E.g.
* the following example from neg/i8900.scala that @smarter came up with:
* We have a type variable X with constraints
*
* X >: 1, X >: x.M
*
* where `x` is a locally nested variable and `x.M` has bad bounds
*
* x.M >: Int | String <: Int & String
*
* If we trust bounds, then the lower bound of `X` is `x.M` since `x.M >: 1`.
* Then even if we correct levels on instantiation to eliminate the local `x`,
* it is alreay too late, we'd get `Int & String` as instance, which does not
* satisfy the original constraint `X >: 1`.
*
* But if `trustBounds` is false, we do not conclude the `x.M >: 1` since
* we compare both bounds and the upper bound `Int & String` is not a supertype
* of `1`. So the lower bound is `1 | x.M` and when we level-avoid that we
* get `1 | Int & String`, which simplifies to `Int`.
*/
private var myTrustBounds = true

inline def withUntrustedBounds(op: => Type): Type =
val saved = myTrustBounds
myTrustBounds = false
try op finally myTrustBounds = saved

def trustBounds: Boolean =
!Config.checkLevelsOnInstantiation || myTrustBounds

def checkReset() =
assert(addConstraintInvocations == 0)
Expand All @@ -97,7 +134,7 @@ trait ConstraintHandling {
level <= maxLevel
|| ctx.isAfterTyper || !ctx.typerState.isCommittable // Leaks in these cases shouldn't break soundness
|| level == Int.MaxValue // See `nestingLevel` above.
|| !Config.checkLevels
|| !Config.checkLevelsOnConstraints

/** If `param` is nested deeper than `maxLevel`, try to instantiate it to a
* fresh type variable of level `maxLevel` and return the new variable.
Expand Down Expand Up @@ -262,16 +299,14 @@ trait ConstraintHandling {
// If `isUpper` is true, ensure that `param <: `bound`, otherwise ensure
// that `param >: bound`.
val narrowedBounds =
val savedHomogenizeArgs = homogenizeArgs
val savedTrustBounds = trustBounds
val saved = homogenizeArgs
homogenizeArgs = Config.alignArgsInAnd
try
trustBounds = false
if isUpper then oldBounds.derivedTypeBounds(lo, hi & bound)
else oldBounds.derivedTypeBounds(lo | bound, hi)
withUntrustedBounds(
if isUpper then oldBounds.derivedTypeBounds(lo, hi & bound)
else oldBounds.derivedTypeBounds(lo | bound, hi))
finally
homogenizeArgs = savedHomogenizeArgs
trustBounds = savedTrustBounds
homogenizeArgs = saved
//println(i"narrow bounds for $param from $oldBounds to $narrowedBounds")
val c1 = constraint.updateEntry(param, narrowedBounds)
(c1 eq constraint)
Expand Down Expand Up @@ -431,24 +466,98 @@ trait ConstraintHandling {
}
}

/** Fix instance type `tp` by avoidance so that it does not contain references
* to types at level > `maxLevel`.
* @param tp the type to be fixed
* @param fromBelow whether type was obtained from lower bound
* @param maxLevel the maximum level of references allowed
* @param param the parameter that was instantiated
*/
private def fixLevels(tp: Type, fromBelow: Boolean, maxLevel: Int, param: TypeParamRef)(using Context) =

def needsFix(tp: NamedType) =
(tp.prefix eq NoPrefix) && tp.symbol.nestingLevel > maxLevel

/** An accumulator that determines whether levels need to be fixed
* and computes on the side sets of nested type variables that need
* to be instantiated.
*/
class NeedsLeveling extends TypeAccumulator[Boolean]:
if !fromBelow then variance = -1

/** Nested type variables that should be instiated to theor lower (respoctively
* upper) bounds.
*/
var nestedVarsLo, nestedVarsHi: SimpleIdentitySet[TypeVar] = SimpleIdentitySet.empty

def apply(need: Boolean, tp: Type) =
need || tp.match
case tp: NamedType =>
needsFix(tp)
|| !stopBecauseStaticOrLocal(tp) && apply(need, tp.prefix)
case tp: TypeVar =>
val inst = tp.instanceOpt
if inst.exists then apply(need, inst)
else if tp.nestingLevel > maxLevel then
if variance > 0 then nestedVarsLo += tp
else if variance < 0 then nestedVarsHi += tp
else
// For invariant type variables, we use a different strategy.
// Rather than instantiating to a bound and then propagating in an
// AvoidMap, change the nesting level of an invariant type
// variable to `maxLevel`. This means that the type variable will be
// instantiated later to a less nested type. If there are other references
// to the same type variable that do not come from the type undergoing
// `fixLevels`, this could lead to coarser types. But it has the potential
// to give a better approximation for the current type, since it avoids forming
// a Range in invariant position, which can lead to very coarse types further out.
constr.println(i"widening nesting level of type variable $tp from ${tp.nestingLevel} to $maxLevel")
ctx.typerState.setNestingLevel(tp, maxLevel)
true
else false
case _ =>
foldOver(need, tp)
end NeedsLeveling

class LevelAvoidMap extends TypeOps.AvoidMap:
if !fromBelow then variance = -1
def toAvoid(tp: NamedType) = needsFix(tp)

if !Config.checkLevelsOnInstantiation || ctx.isAfterTyper then tp
else
val needsLeveling = NeedsLeveling()
if needsLeveling(false, tp) then
typr.println(i"instance $tp for $param needs leveling to $maxLevel, nested = ${needsLeveling.nestedVarsLo.toList} | ${needsLeveling.nestedVarsHi.toList}")
needsLeveling.nestedVarsLo.foreach(_.instantiate(fromBelow = true))
needsLeveling.nestedVarsHi.foreach(_.instantiate(fromBelow = false))
LevelAvoidMap()(tp)
else tp
end fixLevels

/** Solve constraint set for given type parameter `param`.
* If `fromBelow` is true the parameter is approximated by its lower bound,
* otherwise it is approximated by its upper bound, unless the upper bound
* contains a reference to the parameter itself (such occurrences can arise
* for F-bounded types, `addOneBound` ensures that they never occur in the
* lower bound).
* The solved type is not allowed to contain references to types nested deeper
* than `maxLevel`.
* Wildcard types in bounds are approximated by their upper or lower bounds.
* The constraint is left unchanged.
* @return the instantiating type
* @pre `param` is in the constraint's domain.
*/
final def approximation(param: TypeParamRef, fromBelow: Boolean)(using Context): Type =
final def approximation(param: TypeParamRef, fromBelow: Boolean, maxLevel: Int)(using Context): Type =
constraint.entry(param) match
case entry: TypeBounds =>
val useLowerBound = fromBelow || param.occursIn(entry.hi)
val inst = if useLowerBound then fullLowerBound(param) else fullUpperBound(param)
typr.println(s"approx ${param.show}, from below = $fromBelow, inst = ${inst.show}")
inst
val rawInst = withUntrustedBounds(
if useLowerBound then fullLowerBound(param) else fullUpperBound(param))
val levelInst = fixLevels(rawInst, fromBelow, maxLevel, param)
if levelInst ne rawInst then
typr.println(i"level avoid for $maxLevel: $rawInst --> $levelInst")
typr.println(i"approx $param, from below = $fromBelow, inst = $levelInst")
levelInst
case inst =>
assert(inst.exists, i"param = $param\nconstraint = $constraint")
inst
Expand Down Expand Up @@ -560,9 +669,11 @@ trait ConstraintHandling {
* lower bounds; otherwise it is the glb of its upper bounds. However,
* a lower bound instantiation can be a singleton type only if the upper bound
* is also a singleton type.
* The instance type is not allowed to contain references to types nested deeper
* than `maxLevel`.
*/
def instanceType(param: TypeParamRef, fromBelow: Boolean)(using Context): Type = {
val approx = approximation(param, fromBelow).simplified
def instanceType(param: TypeParamRef, fromBelow: Boolean, maxLevel: Int)(using Context): Type = {
val approx = approximation(param, fromBelow, maxLevel).simplified
if fromBelow then
val widened = widenInferred(approx, param)
// Widening can add extra constraints, in particular the widened type might
Expand All @@ -572,7 +683,7 @@ trait ConstraintHandling {
// (we do not check for non-toplevel occurences: those should never occur
// since `addOneBound` disallows recursive lower bounds).
if constraint.occursAtToplevel(param, widened) then
instanceType(param, fromBelow)
instanceType(param, fromBelow, maxLevel)
else
widened
else
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/Contexts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ object Contexts {
protected def scope_=(scope: Scope): Unit = _scope = scope
final def scope: Scope = _scope

/** The current type comparer */
/** The current typerstate */
private var _typerState: TyperState = _
protected def typerState_=(typerState: TyperState): Unit = _typerState = typerState
final def typerState: TyperState = _typerState
Expand Down
8 changes: 4 additions & 4 deletions compiler/src/dotty/tools/dotc/core/GadtConstraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ sealed abstract class GadtConstraint extends Showable {
def isNarrowing: Boolean

/** See [[ConstraintHandling.approximation]] */
def approximation(sym: Symbol, fromBelow: Boolean)(using Context): Type
def approximation(sym: Symbol, fromBelow: Boolean, maxLevel: Int = Int.MaxValue)(using Context): Type

def symbols: List[Symbol]

Expand Down Expand Up @@ -205,9 +205,9 @@ final class ProperGadtConstraint private(

def isNarrowing: Boolean = wasConstrained

override def approximation(sym: Symbol, fromBelow: Boolean)(using Context): Type = {
override def approximation(sym: Symbol, fromBelow: Boolean, maxLevel: Int)(using Context): Type = {
val res =
approximation(tvarOrError(sym).origin, fromBelow = fromBelow) match
approximation(tvarOrError(sym).origin, fromBelow, maxLevel) match
case tpr: TypeParamRef =>
// Here we do externalization when the returned type is a TypeParamRef,
// b/c ConstraintHandling.approximation may return internal types when
Expand Down Expand Up @@ -317,7 +317,7 @@ final class ProperGadtConstraint private(
override def addToConstraint(params: List[Symbol])(using Context): Boolean = unsupported("EmptyGadtConstraint.addToConstraint")
override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = unsupported("EmptyGadtConstraint.addBound")

override def approximation(sym: Symbol, fromBelow: Boolean)(using Context): Type = unsupported("EmptyGadtConstraint.approximation")
override def approximation(sym: Symbol, fromBelow: Boolean, maxLevel: Int)(using Context): Type = unsupported("EmptyGadtConstraint.approximation")

override def symbols: List[Symbol] = Nil

Expand Down
10 changes: 5 additions & 5 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2869,11 +2869,11 @@ object TypeComparer {
def subtypeCheckInProgress(using Context): Boolean =
comparing(_.subtypeCheckInProgress)

def instanceType(param: TypeParamRef, fromBelow: Boolean)(using Context): Type =
comparing(_.instanceType(param, fromBelow))
def instanceType(param: TypeParamRef, fromBelow: Boolean, maxLevel: Int = Int.MaxValue)(using Context): Type =
comparing(_.instanceType(param, fromBelow, maxLevel))

def approximation(param: TypeParamRef, fromBelow: Boolean)(using Context): Type =
comparing(_.approximation(param, fromBelow))
def approximation(param: TypeParamRef, fromBelow: Boolean, maxLevel: Int = Int.MaxValue)(using Context): Type =
comparing(_.approximation(param, fromBelow, maxLevel))

def bounds(param: TypeParamRef)(using Context): TypeBounds =
comparing(_.bounds(param))
Expand Down Expand Up @@ -2953,7 +2953,7 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
case param @ TypeParamRef(b, n) if b eq caseLambda =>
insts(n) =
if canApprox then
approximation(param, fromBelow = variance >= 0).simplified
approximation(param, fromBelow = variance >= 0, Int.MaxValue).simplified
else constraint.entry(param) match
case entry: TypeBounds =>
val lo = fullLowerBound(param)
Expand Down
47 changes: 36 additions & 11 deletions compiler/src/dotty/tools/dotc/core/TyperState.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import config.Config
import config.Printers.constr
import collection.mutable
import java.lang.ref.WeakReference
import util.Stats
import util.{Stats, SimpleIdentityMap}
import Decorators._

import scala.annotation.internal.sharable
Expand All @@ -23,24 +23,26 @@ object TyperState {
.setReporter(new ConsoleReporter())
.setCommittable(true)

opaque type Snapshot = (Constraint, TypeVars, TypeVars)
type LevelMap = SimpleIdentityMap[TypeVar, Integer]

opaque type Snapshot = (Constraint, TypeVars, LevelMap)

extension (ts: TyperState)
def snapshot()(using Context): Snapshot =
var previouslyInstantiated: TypeVars = SimpleIdentitySet.empty
for tv <- ts.ownedVars do if tv.inst.exists then previouslyInstantiated += tv
(ts.constraint, ts.ownedVars, previouslyInstantiated)
(ts.constraint, ts.ownedVars, ts.upLevels)

def resetTo(state: Snapshot)(using Context): Unit =
val (c, tvs, previouslyInstantiated) = state
for tv <- tvs do
if tv.inst.exists && !previouslyInstantiated.contains(tv) then
val (constraint, ownedVars, upLevels) = state
for tv <- ownedVars do
if !ts.ownedVars.contains(tv) then // tv has been instantiated
tv.resetInst(ts)
ts.ownedVars = tvs
ts.constraint = c
ts.constraint = constraint
ts.ownedVars = ownedVars
ts.upLevels = upLevels
}

class TyperState() {
import TyperState.LevelMap

private var myId: Int = _
def id: Int = myId
Expand Down Expand Up @@ -89,6 +91,8 @@ class TyperState() {
def ownedVars: TypeVars = myOwnedVars
def ownedVars_=(vs: TypeVars): Unit = myOwnedVars = vs

private var upLevels: LevelMap = _

/** Initializes all fields except reporter, isCommittable, which need to be
* set separately.
*/
Expand All @@ -99,20 +103,35 @@ class TyperState() {
this.myConstraint = constraint
this.previousConstraint = constraint
this.myOwnedVars = SimpleIdentitySet.empty
this.upLevels = SimpleIdentityMap.empty
this.isCommitted = false
this

/** A fresh typer state with the same constraint as this one. */
def fresh(reporter: Reporter = StoreReporter(this.reporter, fromTyperState = true),
committable: Boolean = this.isCommittable): TyperState =
util.Stats.record("TyperState.fresh")
TyperState().init(this, this.constraint)
val ts = TyperState().init(this, this.constraint)
.setReporter(reporter)
.setCommittable(committable)
ts.upLevels = upLevels
ts

/** The uninstantiated variables */
def uninstVars: collection.Seq[TypeVar] = constraint.uninstVars

/** The nestingLevel of `tv` in this typer state */
def nestingLevel(tv: TypeVar): Int =
val own = upLevels(tv)
if own == null then tv.initNestingLevel else own.intValue()

/** Set the nestingLevel of `tv` in this typer state
* @pre this level must be smaller than `tv.initNestingLevel`
*/
def setNestingLevel(tv: TypeVar, level: Int) =
assert(level < tv.initNestingLevel)
upLevels = upLevels.updated(tv, level)

/** The closest ancestor of this typer state (including possibly this typer state itself)
* which is not yet committed, or which does not have a parent.
*/
Expand Down Expand Up @@ -164,6 +183,12 @@ class TyperState() {
if !ownedVars.isEmpty then ownedVars.foreach(targetState.includeVar)
else
targetState.mergeConstraintWith(this)

upLevels.foreachBinding { (tv, level) =>
if level < targetState.nestingLevel(tv) then
targetState.setNestingLevel(tv, level)
}

targetState.gc()
isCommitted = true
ownedVars = SimpleIdentitySet.empty
Expand Down
Loading

0 comments on commit 3e20051

Please sign in to comment.