Skip to content

Commit

Permalink
Three changes to typing rules
Browse files Browse the repository at this point in the history
The following two rules replace scala#13657:

 1. Exploit capture monotonicity in the apply rule, as discussed in scala#14387.
 2. A rule to make typing nested classes more flexible as discussed in scala#14390.

There's also a bug fix where we now enforce a previously missing subcapturing relationship
between the capture set of parent of a class and the capture set of the class itself. Clearly
a class captures all variables captured by one of its parent classes.
  • Loading branch information
odersky committed Aug 29, 2022
1 parent 57a6bc0 commit 127a223
Show file tree
Hide file tree
Showing 11 changed files with 198 additions and 66 deletions.
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/cc/CaptureSet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ sealed abstract class CaptureSet extends Showable:
extension (x: CaptureRef) private def subsumes(y: CaptureRef) =
(x eq y)
|| y.match
case y: TermRef => y.prefix eq x // ^^^ y.prefix.subsumes(x) ?
case y: TermRef => y.prefix eq x
case _ => false

/** {x} <:< this where <:< is subcapturing, but treating all variables
Expand Down
38 changes: 18 additions & 20 deletions compiler/src/dotty/tools/dotc/transform/Recheck.scala
Original file line number Diff line number Diff line change
Expand Up @@ -126,15 +126,12 @@ abstract class Recheck extends Phase, IdentityDenotTransformer:
bindType

def recheckValDef(tree: ValDef, sym: Symbol)(using Context): Unit =
if !tree.rhs.isEmpty then recheckRHS(tree.rhs, sym.info, sym)
if !tree.rhs.isEmpty then recheck(tree.rhs, sym.info)

def recheckDefDef(tree: DefDef, sym: Symbol)(using Context): Unit =
val rhsCtx = linkConstructorParams(sym).withOwner(sym)
if !tree.rhs.isEmpty && !sym.isInlineMethod && !sym.isEffectivelyErased then
inContext(rhsCtx) { recheckRHS(tree.rhs, recheck(tree.tpt), sym) }

def recheckRHS(tree: Tree, pt: Type, sym: Symbol)(using Context): Type =
recheck(tree, pt)
inContext(rhsCtx) { recheck(tree.rhs, recheck(tree.tpt)) }

def recheckTypeDef(tree: TypeDef, sym: Symbol)(using Context): Type =
recheck(tree.rhs)
Expand Down Expand Up @@ -358,21 +355,22 @@ abstract class Recheck extends Phase, IdentityDenotTransformer:
// Don't report closure nodes, since their span is a point; wait instead
// for enclosing block to preduce an error
case _ =>
val actual = tpe.widenExpr
val expected = pt.widenExpr
//println(i"check conforms $actual <:< $expected")
val isCompatible =
actual <:< expected
|| expected.isRepeatedParam
&& actual <:< expected.translateFromRepeated(toArray = tree.tpe.isRef(defn.ArrayClass))
if !isCompatible then
recheckr.println(i"conforms failed for ${tree}: $tpe vs $expected")
err.typeMismatch(tree.withType(tpe), expected)
else if debugSuccesses then
tree match
case _: Ident =>
println(i"SUCCESS $tree:\n${TypeComparer.explained(_.isSubType(actual, expected))}")
case _ =>
checkConformsExpr(tpe, tpe.widenExpr, pt.widenExpr, tree)

def checkConformsExpr(original: Type, actual: Type, expected: Type, tree: Tree)(using Context): Unit =
//println(i"check conforms $actual <:< $expected")
val isCompatible =
actual <:< expected
|| expected.isRepeatedParam
&& actual <:< expected.translateFromRepeated(toArray = tree.tpe.isRef(defn.ArrayClass))
if !isCompatible then
recheckr.println(i"conforms failed for ${tree}: $original vs $expected")
err.typeMismatch(tree.withType(original), expected)
else if debugSuccesses then
tree match
case _: Ident =>
println(i"SUCCESS $tree:\n${TypeComparer.explained(_.isSubType(actual, expected))}")
case _ =>

def checkUnit(unit: CompilationUnit)(using Context): Unit =
recheck(unit.tpdTree)
Expand Down
82 changes: 65 additions & 17 deletions compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import transform.Recheck
import Recheck.*
import scala.collection.mutable
import CaptureSet.withCaptureSetsExplained
import StdNames.nme
import reporting.trace

object CheckCaptures:
Expand Down Expand Up @@ -213,22 +214,6 @@ class CheckCaptures extends Recheck:
interpolateVarsIn(tree.tpt)
curEnv = saved

override def recheckRHS(tree: Tree, pt: Type, sym: Symbol)(using Context): Type =
val pt1 = pt match
case CapturingType(core, refs, _)
if sym.owner.isClass && !sym.owner.isExtensibleClass
&& refs.elems.contains(sym.owner.thisType) =>
val paramCaptures =
sym.paramSymss.flatten.foldLeft(CaptureSet.empty) { (cs, p) =>
val pcs = p.info.captureSet
(cs ++ (if pcs.isConst then pcs else CaptureSet.universal)).asConst
}
val declaredCaptures = sym.owner.asClass.givenSelfType.captureSet
pt.derivedCapturingType(core, refs ++ (declaredCaptures -- paramCaptures))
case _ =>
pt
recheck(tree, pt1)

override def recheckClassDef(tree: TypeDef, impl: Template, cls: ClassSymbol)(using Context): Type =
for param <- cls.paramGetters do
if param.is(Private) && !param.info.captureSet.isAlwaysEmpty then
Expand All @@ -237,6 +222,8 @@ class CheckCaptures extends Recheck:
param.srcPos)
val saved = curEnv
val localSet = capturedVars(cls)
for parent <- impl.parents do
checkSubset(capturedVars(parent.tpe.classSymbol), localSet, parent.srcPos)
if !localSet.isAlwaysEmpty then curEnv = Env(cls, localSet, false, curEnv)
try super.recheckClassDef(tree, impl, cls)
finally curEnv = saved
Expand Down Expand Up @@ -289,9 +276,34 @@ class CheckCaptures extends Recheck:
finally curEnv = curEnv.outer
recheckFinish(result, arg, pt)

/** A specialized implementation of the apply rule from https://github.com/lampepfl/dotty/discussions/14387:
*
* E |- f: Cf (Ra -> Cr Rr)
* E |- a: Ra
* ------------------------
* E |- f a: Cr /\ {f} Rr
*
* Specialized for the case where `f` is a tracked and the arguments are pure.
* This replaces the previous rule #13657 while still allowing the code in pos/lazylists1.scala.
* We could consider generalizing to the case where the function arguments have non-empty
* capture sets as suggested in #14387, but that would make capture set computations more complex,
* so we should also evaluate the performance impact.
*/
override def recheckApply(tree: Apply, pt: Type)(using Context): Type =
includeCallCaptures(tree.symbol, tree.srcPos)
super.recheckApply(tree, pt)
super.recheckApply(tree, pt) match
case tp @ CapturingType(tp1, refs, kind) =>
tree.fun match
case Select(qual, nme.apply)
if defn.isFunctionType(qual.tpe.widen) =>
qual.tpe match
case ref: CaptureRef
if ref.isTracked && tree.args.forall(_.tpe.captureSet.isAlwaysEmpty) =>
tp.derivedCapturingType(tp1, refs ** ref.singletonCaptureSet)
.showing(i"narrow $tree: $tp --> $result", capt)
case _ => tp
case _ => tp
case tp => tp

override def recheck(tree: Tree, pt: Type = WildcardType)(using Context): Type =
val res = super.recheck(tree, pt)
Expand Down Expand Up @@ -319,6 +331,42 @@ class CheckCaptures extends Recheck:
case _ =>
super.recheckFinish(tpe, tree, pt)

/** This method implements the rule outlined in #14390:
* When checking an expression `e: Ca Ta` against an expected type `Cx Tx`
* where the capture set of `Cx` contains this and any method inside the class
* `Cls` of `this` that contains `e` has only pure parameters, drop from `Ca`
* all references to variables or this references outside `Cls`. These are all
* accessed through this, so are already accounted for by `Cx`.
*/
override def checkConformsExpr(original: Type, actual: Type, expected: Type, tree: Tree)(using Context): Unit =
def isPure(info: Type): Boolean = info match
case info: PolyType => isPure(info.resType)
case info: MethodType => info.paramInfos.forall(_.captureSet.isAlwaysEmpty) && isPure(info.resType)
case _ => true
def isPureContext(owner: Symbol, limit: Symbol): Boolean =
if owner == limit then true
else if !owner.exists then false
else isPure(owner.info) && isPureContext(owner.owner, limit)
val actual1 = (expected, actual.widen) match
case (CapturingType(ecore, erefs, _), actualw @ CapturingType(acore, arefs, _)) =>
val arefs1 = (arefs /: erefs.elems) { (arefs1, eref) =>
eref match
case eref: ThisType if isPureContext(ctx.owner, eref.cls) =>
arefs1.filter {
case aref1: TermRef => !eref.cls.isContainedIn(aref1.symbol.owner)
case aref1: ThisType => !eref.cls.isContainedIn(aref1.cls)
case _ => true
}
case _ =>
arefs1
}
if arefs1 eq arefs then actual
else actualw.derivedCapturingType(acore, arefs1)
.showing(i"healing $actual --> $result", capt)
case _ =>
actual
super.checkConformsExpr(original, actual1, expected, tree)

override def checkUnit(unit: CompilationUnit)(using Context): Unit =
Setup(preRecheckPhase, thisPhase, recheckDef)
.traverse(ctx.compilationUnit.tpdTree)
Expand Down
4 changes: 2 additions & 2 deletions tests/neg-custom-args/captures/lazylist.check
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,6 @@ longer explanation available when compiling with `-explain`
17 | def tail = xs() // error: cannot have an inferred type
| ^^^^^^^^^^^^^^^
| Non-local method tail cannot have an inferred result type
| {*} lazylists.LazyList[T]
| with non-empty capture set {*}.
| {LazyCons.this.xs} lazylists.LazyList[T]
| with non-empty capture set {LazyCons.this.xs}.
| The type needs to be declared explicitly.
10 changes: 5 additions & 5 deletions tests/neg-custom-args/captures/lazylists1.check
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylists1.scala:25:63 -----------------------------------
25 | def concat(other: {f} LazyList[A]): {this} LazyList[A] = ??? : ({xs, f} LazyList[A]) // error
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^
| Found: {xs, f} LazyList[A]
| Required: {Mapped.this, xs} LazyList[A]
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylists1.scala:25:66 -----------------------------------
25 | def concat(other: {f} LazyList[A]): {this, f} LazyList[A] = ??? : ({xs, f} LazyList[A]) // error
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^
| Found: {xs, f} LazyList[A]
| Required: {Mapped.this, f} LazyList[A]

longer explanation available when compiling with `-explain`
2 changes: 1 addition & 1 deletion tests/neg-custom-args/captures/lazylists1.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@ extension [A](xs: {*} LazyList[A])
def head: B = f(xs.head)
def tail: {this} LazyList[B] = xs.tail.map(f) // OK
def drop(n: Int): {this} LazyList[B] = ??? : ({xs, f} LazyList[B]) // OK
def concat(other: {f} LazyList[A]): {this} LazyList[A] = ??? : ({xs, f} LazyList[A]) // error
def concat(other: {f} LazyList[A]): {this, f} LazyList[A] = ??? : ({xs, f} LazyList[A]) // error
new Mapped

32 changes: 22 additions & 10 deletions tests/neg-custom-args/captures/lazylists2.check
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,29 @@ longer explanation available when compiling with `-explain`
32 | new Mapped

longer explanation available when compiling with `-explain`
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylists2.scala:41:48 -----------------------------------
41 | def tail: {this} LazyList[B] = xs.tail.map(f) // error
| ^^^^^^^^^^^^^^
| Found: {f} LazyList[B]
| Required: {xs} LazyList[B]
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylists2.scala:36:4 ------------------------------------
36 | final class Mapped extends LazyList[B]: // error
| ^
| Found: {f, xs} LazyList[B]
| Required: {xs} LazyList[B]
37 | this: ({xs} Mapped) =>
38 | def isEmpty = false
39 | def head: B = f(xs.head)
40 | def tail: {this} LazyList[B] = xs.tail.map(f)
41 | new Mapped

longer explanation available when compiling with `-explain`
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylists2.scala:59:48 -----------------------------------
59 | def tail: {this} LazyList[B] = xs.tail.map(f) // error
| ^^^^^^^^^^^^^^
| Found: {f} LazyList[B]
| Required: {Mapped.this} LazyList[B]
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylists2.scala:54:4 ------------------------------------
54 | class Mapped extends LazyList[B]: // error
| ^
| Found: {f, xs} LazyList[B]
| Required: LazyList[B]
55 | this: ({xs, f} Mapped) =>
56 | def isEmpty = false
57 | def head: B = f(xs.head)
58 | def tail: {this} LazyList[B] = xs.tail.map(f)
59 | class Mapped2 extends Mapped:
60 | this: Mapped =>
61 | new Mapped2

longer explanation available when compiling with `-explain`
8 changes: 4 additions & 4 deletions tests/neg-custom-args/captures/lazylists2.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ extension [A](xs: {*} LazyList[A])
new Mapped

def map3[B](f: A => B): {xs} LazyList[B] =
final class Mapped extends LazyList[B]:
final class Mapped extends LazyList[B]: // error
this: ({xs} Mapped) =>

def isEmpty = false
def head: B = f(xs.head)
def tail: {this} LazyList[B] = xs.tail.map(f) // error
def tail: {this} LazyList[B] = xs.tail.map(f)
new Mapped

def map4[B](f: A => B): {xs} LazyList[B] =
Expand All @@ -51,12 +51,12 @@ extension [A](xs: {*} LazyList[A])
new Mapped

def map5[B](f: A => B): LazyList[B] =
class Mapped extends LazyList[B]:
class Mapped extends LazyList[B]: // error
this: ({xs, f} Mapped) =>

def isEmpty = false
def head: B = f(xs.head)
def tail: {this} LazyList[B] = xs.tail.map(f) // error
def tail: {this} LazyList[B] = xs.tail.map(f)
class Mapped2 extends Mapped:
this: Mapped =>
new Mapped2
Expand Down
68 changes: 68 additions & 0 deletions tests/pos-custom-args/captures/lazylists-exceptions.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import language.experimental.saferExceptions
import annotation.unchecked.uncheckedVariance

trait LazyList[+A]:
this: {*} LazyList[A] =>

def isEmpty: Boolean
def head: A
def tail: {this} LazyList[A]

object LazyNil extends LazyList[Nothing]:
def isEmpty: Boolean = true
def head = ???
def tail = ???

final class LazyCons[+T](val x: T, val xs: () => {*} LazyList[T]) extends LazyList[T]:
this: {*} LazyList[T] =>

var forced = false
var cache: {this} LazyList[T @uncheckedVariance] = compiletime.uninitialized

private def force =
if !forced then
cache = xs()
forced = true
cache

def isEmpty = false
def head = x
def tail: {this} LazyList[T] = force
end LazyCons

extension [A](xs: {*} LazyList[A])
def map[B](f: A => B): {xs, f} LazyList[B] =
if xs.isEmpty then LazyNil
else LazyCons(f(xs.head), () => xs.tail.map(f))

def filter(p: A => Boolean): {xs, p} LazyList[A] =
if xs.isEmpty then LazyNil
else if p(xs.head) then LazyCons(xs.head, () => xs.tail.filter(p))
else xs.tail.filter(p)

def concat(ys: {*} LazyList[A]): {xs, ys} LazyList[A] =
if xs.isEmpty then ys
else LazyCons(xs.head, () => xs.tail.concat(ys))
end extension

class Ex1 extends Exception
class Ex2 extends Exception

def test(using cap1: CanThrow[Ex1], cap2: CanThrow[Ex2]) =
val xs = LazyCons(1, () => LazyNil)

def f(x: Int): Int throws Ex1 =
if x < 0 then throw Ex1()
x * x

def g(x: Int): Int throws Ex1 =
if x < 0 then throw Ex1()
x * x

def x1 = xs.map(f)
def x1c: {cap1} LazyList[Int] = x1

def x2 = x1.concat(xs.map(g).filter(_ > 0))
def x2c: {cap1, cap2} LazyList[Int] = x2


1 change: 0 additions & 1 deletion tests/pos-custom-args/captures/lazylists.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ extension [A](xs: {*} LazyList[A])
def isEmpty = false
def head: B = f(xs.head)
def tail: {this} LazyList[B] = xs.tail.map(f) // OK
def concat(other: {f} LazyList[A]): {this, f} LazyList[A] = ??? : {xs, f} LazyList[A] // OK
if xs.isEmpty then LazyNil
else new Mapped

Expand Down
Loading

0 comments on commit 127a223

Please sign in to comment.