Skip to content

Commit

Permalink
Merge pull request #8867 from dotty-staging/fix-#8861
Browse files Browse the repository at this point in the history
Fix #8861: Avoid parameters when instantiating closure results
  • Loading branch information
odersky authored May 9, 2020
2 parents 1896c2b + e97b278 commit b4338a8
Show file tree
Hide file tree
Showing 8 changed files with 145 additions and 30 deletions.
14 changes: 8 additions & 6 deletions compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala
Original file line number Diff line number Diff line change
Expand Up @@ -296,10 +296,10 @@ trait ConstraintHandling[AbstractContext] {

/** Widen inferred type `inst` with upper `bound`, according to the following rules:
* 1. If `inst` is a singleton type, or a union containing some singleton types,
* widen (all) the singleton type(s), provied the result is a subtype of `bound`
* widen (all) the singleton type(s), provided the result is a subtype of `bound`
* (i.e. `inst.widenSingletons <:< bound` succeeds with satisfiable constraint)
* 2. If `inst` is a union type, approximate the union type from above by an intersection
* of all common base types, provied the result is a subtype of `bound`.
* of all common base types, provided the result is a subtype of `bound`.
*
* Don't do these widenings if `bound` is a subtype of `scala.Singleton`.
* Also, if the result of these widenings is a TypeRef to a module class,
Expand All @@ -312,15 +312,17 @@ trait ConstraintHandling[AbstractContext] {
def widenInferred(inst: Type, bound: Type)(implicit actx: AbstractContext): Type = {
def widenOr(tp: Type) = {
val tpw = tp.widenUnion
if ((tpw ne tp) && tpw <:< bound) tpw else tp
if (tpw ne tp) && (tpw <:< bound) then tpw else tp
}
def widenSingle(tp: Type) = {
val tpw = tp.widenSingletons
if ((tpw ne tp) && tpw <:< bound) tpw else tp
if (tpw ne tp) && (tpw <:< bound) then tpw else tp
}
def isSingleton(tp: Type): Boolean = tp match
case WildcardType(optBounds) => optBounds.exists && isSingleton(optBounds.bounds.hi)
case _ => isSubTypeWhenFrozen(tp, defn.SingletonType)
val wideInst =
if (isSubTypeWhenFrozen(bound, defn.SingletonType)) inst
else widenOr(widenSingle(inst))
if isSingleton(bound) then inst else widenOr(widenSingle(inst))
wideInst match
case wideInst: TypeRef if wideInst.symbol.is(Module) =>
TermRef(wideInst.prefix, wideInst.symbol.sourceModule)
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/GadtConstraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ final class ProperGadtConstraint private(
)

val tvars = params.lazyZip(poly1.paramRefs).map { (sym, paramRef) =>
val tv = new TypeVar(paramRef, creatorState = null)
val tv = TypeVar(paramRef, creatorState = null)
mapping = mapping.updated(sym, tv)
reverseMapping = reverseMapping.updated(tv.origin, sym)
tv
Expand Down
14 changes: 14 additions & 0 deletions compiler/src/dotty/tools/dotc/core/SymDenotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1459,6 +1459,14 @@ object SymDenotations {
else if is(Contravariant) then Contravariant
else EmptyFlags

/** The length of the owner chain of this symbol. 1 for _root_, 0 for NoSymbol */
def nestingLevel(using Context): Int =
@tailrec def recur(d: SymDenotation, n: Int): Int = d match
case NoDenotation => n
case d: ClassDenotation => d.nestingLevel + n // profit from the cache in ClassDenotation
case _ => recur(d.owner, n + 1)
recur(this, 0)

/** The flags to be used for a type parameter owned by this symbol.
* Overridden by ClassDenotation.
*/
Expand Down Expand Up @@ -2160,6 +2168,12 @@ object SymDenotations {

override def registeredCompanion(implicit ctx: Context) = { ensureCompleted(); myCompanion }
override def registeredCompanion_=(c: Symbol) = { myCompanion = c }

private var myNestingLevel = -1

override def nestingLevel(using Context) =
if myNestingLevel == -1 then myNestingLevel = owner.nestingLevel + 1
myNestingLevel
}

/** The denotation of a package class.
Expand Down
46 changes: 39 additions & 7 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4110,18 +4110,17 @@ object Types {
*
* @param origin The parameter that's tracked by the type variable.
* @param creatorState The typer state in which the variable was created.
*
* `owningTree` and `owner` are used to determine whether a type-variable can be instantiated
* at some given point. See `Inferencing#interpolateUndetVars`.
*/
final class TypeVar(private var _origin: TypeParamRef, creatorState: TyperState) extends CachedProxyType with ValueType {
final class TypeVar private(initOrigin: TypeParamRef, creatorState: TyperState, nestingLevel: Int) extends CachedProxyType with ValueType {

private var currentOrigin = initOrigin

def origin: TypeParamRef = _origin
def origin: TypeParamRef = currentOrigin

/** Set origin to new parameter. Called if we merge two conflicting constraints.
* See OrderingConstraint#merge, OrderingConstraint#rename
*/
def setOrigin(p: TypeParamRef) = _origin = p
def setOrigin(p: TypeParamRef) = currentOrigin = p

/** The permanent instance type of the variable, or NoType is none is given yet */
private var myInst: Type = NoType
Expand Down Expand Up @@ -4150,6 +4149,36 @@ object Types {
/** Is the variable already instantiated? */
def isInstantiated(implicit ctx: Context): Boolean = instanceOpt.exists

/** Avoid term references in `tp` to parameters or local variables that
* are nested more deeply than the type variable itself.
*/
private def avoidCaptures(tp: Type)(using Context): Type =
val problemSyms = new TypeAccumulator[Set[Symbol]]:
def apply(syms: Set[Symbol], t: Type): Set[Symbol] = t match
case ref @ TermRef(NoPrefix, _)
// AVOIDANCE TODO: Are there other problematic kinds of references?
// Our current tests only give us these, but we might need to generalize this.
if ref.symbol.maybeOwner.nestingLevel > nestingLevel =>
syms + ref.symbol
case _ =>
foldOver(syms, t)
val problems = problemSyms(Set.empty, tp)
if problems.isEmpty then tp
else
val atp = ctx.typer.avoid(tp, problems.toList)
def msg = i"Inaccessible variables captured in instantation of type variable $this.\n$tp was fixed to $atp"
typr.println(msg)
val bound = ctx.typeComparer.fullUpperBound(origin)
if !(atp <:< bound) then
throw new TypeError(s"$msg,\nbut the latter type does not conform to the upper bound $bound")
atp
// AVOIDANCE TODO: This really works well only if variables are instantiated from below
// If we hit a problematic symbol while instantiating from above, then avoidance
// will widen the instance type further. This could yield an alias, which would be OK.
// But it also could yield a true super type which would then fail the bounds check
// and throw a TypeError. The right thing to do instead would be to avoid "downwards".
// To do this, we need first test cases for that situation.

/** Instantiate variable with given type */
def instantiateWith(tp: Type)(implicit ctx: Context): Type = {
assert(tp ne this, s"self instantiation of ${tp.show}, constraint = ${ctx.typerState.constraint.show}")
Expand All @@ -4168,7 +4197,7 @@ object Types {
* is also a singleton type.
*/
def instantiate(fromBelow: Boolean)(implicit ctx: Context): Type =
instantiateWith(ctx.typeComparer.instanceType(origin, fromBelow))
instantiateWith(avoidCaptures(ctx.typeComparer.instanceType(origin, fromBelow)))

/** For uninstantiated type variables: Is the lower bound different from Nothing? */
def hasLowerBound(implicit ctx: Context): Boolean =
Expand Down Expand Up @@ -4200,6 +4229,9 @@ object Types {
s"TypeVar($origin$instStr)"
}
}
object TypeVar:
def apply(initOrigin: TypeParamRef, creatorState: TyperState)(using Context) =
new TypeVar(initOrigin, creatorState, ctx.owner.nestingLevel)

type TypeVars = SimpleIdentitySet[TypeVar]

Expand Down
37 changes: 23 additions & 14 deletions compiler/src/dotty/tools/dotc/typer/Namer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -196,11 +196,11 @@ class Namer { typer: Typer =>

import untpd._

val TypedAhead: Property.Key[tpd.Tree] = new Property.Key
val ExpandedTree: Property.Key[untpd.Tree] = new Property.Key
val TypedAhead : Property.Key[tpd.Tree] = new Property.Key
val ExpandedTree : Property.Key[untpd.Tree] = new Property.Key
val ExportForwarders: Property.Key[List[tpd.MemberDef]] = new Property.Key
val SymOfTree: Property.Key[Symbol] = new Property.Key
val Deriver: Property.Key[typer.Deriver] = new Property.Key
val SymOfTree : Property.Key[Symbol] = new Property.Key
val Deriver : Property.Key[typer.Deriver] = new Property.Key

/** A partial map from unexpanded member and pattern defs and to their expansions.
* Populated during enterSyms, emptied during typer.
Expand Down Expand Up @@ -1440,13 +1440,10 @@ class Namer { typer: Typer =>
// instead of widening to the underlying module class types.
// We also drop the @Repeated annotation here to avoid leaking it in method result types
// (see run/inferred-repeated-result).
def widenRhs(tp: Type): Type = {
val tp1 = tp.widenTermRefExpr.simplified match
def widenRhs(tp: Type): Type =
tp.widenTermRefExpr.simplified match
case ctp: ConstantType if isInlineVal => ctp
case ref: TypeRef if ref.symbol.is(ModuleClass) => tp
case tp => tp.widenUnion
tp1.dropRepeatedAnnot
}
case tp => ctx.typeComparer.widenInferred(tp, rhsProto)

// Replace aliases to Unit by Unit itself. If we leave the alias in
// it would be erased to BoxedUnit.
Expand Down Expand Up @@ -1498,9 +1495,21 @@ class Namer { typer: Typer =>
if (isFullyDefined(tpe, ForceDegree.none)) tpe
else typedAheadExpr(mdef.rhs, tpe).tpe
case TypedSplice(tpt: TypeTree) if !isFullyDefined(tpt.tpe, ForceDegree.none) =>
val rhsType = typedAheadExpr(mdef.rhs, tpt.tpe).tpe
mdef match {
case mdef: DefDef if mdef.name == nme.ANON_FUN =>
// This case applies if the closure result type contains uninstantiated
// type variables. In this case, constrain the closure result from below
// by the parameter-capture-avoiding type of the body.
val rhsType = typedAheadExpr(mdef.rhs, tpt.tpe).tpe

// The following part is important since otherwise we might instantiate
// the closure result type with a plain functon type that refers
// to local parameters. An example where this happens in `dependent-closures.scala`
// If the code after `val rhsType` is commented out, this file fails pickling tests.
// AVOIDANCE TODO: Follow up why this happens, and whether there
// are better ways to achieve this. It would be good if we could get rid of this code.
// It seems at least partially redundant with the nesting level checking on TypeVar
// instantiation.
val hygienicType = avoid(rhsType, paramss.flatten)
if (!hygienicType.isValueType || !(hygienicType <:< tpt.tpe))
ctx.error(i"return type ${tpt.tpe} of lambda cannot be made hygienic;\n" +
Expand All @@ -1513,10 +1522,10 @@ class Namer { typer: Typer =>
case _ =>
WildcardType
}
val memTpe = paramFn(checkSimpleKinded(typedAheadType(mdef.tpt, tptProto)).tpe)
val mbrTpe = paramFn(checkSimpleKinded(typedAheadType(mdef.tpt, tptProto)).tpe)
if (ctx.explicitNulls && mdef.mods.is(JavaDefined))
JavaNullInterop.nullifyMember(sym, memTpe, mdef.mods.isAllOf(JavaEnumValue))
else memTpe
JavaNullInterop.nullifyMember(sym, mbrTpe, mdef.mods.isAllOf(JavaEnumValue))
else mbrTpe
}

/** The type signature of a DefDef with given symbol */
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -501,8 +501,8 @@ object ProtoTypes {
def newTypeVars(tl: TypeLambda): List[TypeTree] =
for (paramRef <- tl.paramRefs)
yield {
val tt = new TypeVarBinder().withSpan(owningTree.span)
val tvar = new TypeVar(paramRef, state)
val tt = TypeVarBinder().withSpan(owningTree.span)
val tvar = TypeVar(paramRef, state)
state.ownedVars += tvar
tt.withType(tvar)
}
Expand Down
31 changes: 31 additions & 0 deletions tests/neg/i8861.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
object Test {
sealed trait Container { s =>
type A
def visit[R](int: IntV & s.type => R, str: StrV & s.type => R): R
}
final class IntV extends Container { s =>
type A = Int
val i: Int = 42
def visit[R](int: IntV & s.type => R, str: StrV & s.type => R): R = int(this)
}
final class StrV extends Container { s =>
type A = String
val t: String = "hello"
def visit[R](int: IntV & s.type => R, str: StrV & s.type => R): R = str(this)
}

def minimalOk[R](c: Container { type A = R }): R = c.visit[R](
int = vi => vi.i : vi.A,
str = vs => vs.t : vs.A
)
def minimalFail[M](c: Container { type A = M }): M = c.visit(
int = vi => vi.i : vi.A,
str = vs => vs.t : vs.A // error
)

def main(args: Array[String]): Unit = {
val e: Container { type A = String } = new StrV
println(minimalOk(e)) // this one prints "hello"
println(minimalFail(e)) // this one fails with ClassCastException: class java.lang.String cannot be cast to class java.lang.Integer
}
}
27 changes: 27 additions & 0 deletions tests/pos/dependent-closures.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
trait S { type N; def n: N }

def newS[X](n: X): S { type N = X } = ???

def test =
val ss: List[S] = ???
val cl1 = (s: S) => newS(s.n)
val cl2: (s: S) => S { type N = s.N } = cl1
def f[R](cl: (s: S) => R) = cl
val x = f(s => newS(s.n))
val x1: (s: S) => S = x
// If the code in `tptProto` of Namer that refers to this
// file is commented out, we see:
// pickling difference for the result type of the closure argument
// before pickling: S => S { type N = s.N }
// after pickling : (s: S) => S { type N = s.N }

ss.map(s => newS(s.n))
// If the code in `tptProto` of Namer that refers to this
// file is commented out, we see a pickling difference like the one above.

def g[R](cl: (s: S) => (S { type N = s.N }, R)) = ???
g(s => (newS(s.n), identity(1)))

def h(cl: (s: S) => S { type N = s.N }) = ???
h(s => newS(s.n))

0 comments on commit b4338a8

Please sign in to comment.