Skip to content

Commit

Permalink
Three changes to typing rules
Browse files Browse the repository at this point in the history
The following two rules replace #13657:

 1. Exploit capture monotonicity in the apply rule, as discussed in #14387.
 2. A rule to make typing nested classes more flexible as discussed in #14390.

There's also a bug fix where we now enforce a previously missing subcapturing relationship
between the capture set of parent of a class and the capture set of the class itself. Clearly
a class captures all variables captured by one of its parent classes.
  • Loading branch information
odersky committed Jan 31, 2022
1 parent 18dd570 commit 9a3f60d
Show file tree
Hide file tree
Showing 11 changed files with 178 additions and 65 deletions.
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/cc/CaptureSet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ sealed abstract class CaptureSet extends Showable:
extension (x: CaptureRef) private def subsumes(y: CaptureRef) =
(x eq y)
|| y.match
case y: TermRef => y.prefix eq x // ^^^ y.prefix.subsumes(x) ?
case y: TermRef => y.prefix eq x
case _ => false

/** {x} <:< this where <:< is subcapturing, but treating all variables
Expand Down
38 changes: 18 additions & 20 deletions compiler/src/dotty/tools/dotc/transform/Recheck.scala
Original file line number Diff line number Diff line change
Expand Up @@ -126,15 +126,12 @@ abstract class Recheck extends Phase, IdentityDenotTransformer:
bindType

def recheckValDef(tree: ValDef, sym: Symbol)(using Context): Unit =
if !tree.rhs.isEmpty then recheckRHS(tree.rhs, sym.info, sym)
if !tree.rhs.isEmpty then recheck(tree.rhs, sym.info)

def recheckDefDef(tree: DefDef, sym: Symbol)(using Context): Unit =
val rhsCtx = linkConstructorParams(sym).withOwner(sym)
if !tree.rhs.isEmpty && !sym.isInlineMethod && !sym.isEffectivelyErased then
inContext(rhsCtx) { recheckRHS(tree.rhs, recheck(tree.tpt), sym) }

def recheckRHS(tree: Tree, pt: Type, sym: Symbol)(using Context): Type =
recheck(tree, pt)
inContext(rhsCtx) { recheck(tree.rhs, recheck(tree.tpt)) }

def recheckTypeDef(tree: TypeDef, sym: Symbol)(using Context): Type =
recheck(tree.rhs)
Expand Down Expand Up @@ -358,21 +355,22 @@ abstract class Recheck extends Phase, IdentityDenotTransformer:
// Don't report closure nodes, since their span is a point; wait instead
// for enclosing block to preduce an error
case _ =>
val actual = tpe.widenExpr
val expected = pt.widenExpr
//println(i"check conforms $actual <:< $expected")
val isCompatible =
actual <:< expected
|| expected.isRepeatedParam
&& actual <:< expected.translateFromRepeated(toArray = tree.tpe.isRef(defn.ArrayClass))
if !isCompatible then
recheckr.println(i"conforms failed for ${tree}: $tpe vs $expected")
err.typeMismatch(tree.withType(tpe), expected)
else if debugSuccesses then
tree match
case _: Ident =>
println(i"SUCCESS $tree:\n${TypeComparer.explained(_.isSubType(actual, expected))}")
case _ =>
checkConformsExpr(tpe, tpe.widenExpr, pt.widenExpr, tree)

def checkConformsExpr(original: Type, actual: Type, expected: Type, tree: Tree)(using Context): Unit =
//println(i"check conforms $actual <:< $expected")
val isCompatible =
actual <:< expected
|| expected.isRepeatedParam
&& actual <:< expected.translateFromRepeated(toArray = tree.tpe.isRef(defn.ArrayClass))
if !isCompatible then
recheckr.println(i"conforms failed for ${tree}: $original vs $expected")
err.typeMismatch(tree.withType(original), expected)
else if debugSuccesses then
tree match
case _: Ident =>
println(i"SUCCESS $tree:\n${TypeComparer.explained(_.isSubType(actual, expected))}")
case _ =>

def checkUnit(unit: CompilationUnit)(using Context): Unit =
recheck(unit.tpdTree)
Expand Down
82 changes: 65 additions & 17 deletions compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import transform.Recheck
import Recheck.*
import scala.collection.mutable
import CaptureSet.withCaptureSetsExplained
import StdNames.nme
import reporting.trace

object CheckCaptures:
Expand Down Expand Up @@ -213,22 +214,6 @@ class CheckCaptures extends Recheck:
interpolateVarsIn(tree.tpt)
curEnv = saved

override def recheckRHS(tree: Tree, pt: Type, sym: Symbol)(using Context): Type =
val pt1 = pt match
case CapturingType(core, refs, _)
if sym.owner.isClass && !sym.owner.isExtensibleClass
&& refs.elems.contains(sym.owner.thisType) =>
val paramCaptures =
sym.paramSymss.flatten.foldLeft(CaptureSet.empty) { (cs, p) =>
val pcs = p.info.captureSet
(cs ++ (if pcs.isConst then pcs else CaptureSet.universal)).asConst
}
val declaredCaptures = sym.owner.asClass.givenSelfType.captureSet
pt.derivedCapturingType(core, refs ++ (declaredCaptures -- paramCaptures))
case _ =>
pt
recheck(tree, pt1)

override def recheckClassDef(tree: TypeDef, impl: Template, cls: ClassSymbol)(using Context): Type =
for param <- cls.paramGetters do
if param.is(Private) && !param.info.captureSet.isAlwaysEmpty then
Expand All @@ -237,6 +222,8 @@ class CheckCaptures extends Recheck:
param.srcPos)
val saved = curEnv
val localSet = capturedVars(cls)
for parent <- impl.parents do
checkSubset(capturedVars(parent.tpe.classSymbol), localSet, parent.srcPos)
if !localSet.isAlwaysEmpty then curEnv = Env(cls, localSet, false, curEnv)
try super.recheckClassDef(tree, impl, cls)
finally curEnv = saved
Expand Down Expand Up @@ -289,9 +276,34 @@ class CheckCaptures extends Recheck:
finally curEnv = curEnv.outer
recheckFinish(result, arg, pt)

/** A specialized implementation of the apply rule from https://github.com/lampepfl/dotty/discussions/14387:
*
* E |- f: Cf (Ra -> Cr Rr)
* E |- a: Ra
* ------------------------
* E |- f a: Cr /\ {f} Rr
*
* Specialized for the case where `f` is a tracked and the arguments are pure.
* This replaces the previous rule #13657 while still allowing the code in pos/lazylists1.scala.
* We could consider generalizing to the case where the function arguments have non-empty
* capture sets as suggested in #14387, but that would make capture set computations more complex,
* so we should also evaluate the performance impact.
*/
override def recheckApply(tree: Apply, pt: Type)(using Context): Type =
includeCallCaptures(tree.symbol, tree.srcPos)
super.recheckApply(tree, pt)
super.recheckApply(tree, pt) match
case tp @ CapturingType(tp1, refs, kind) =>
tree.fun match
case Select(qual, nme.apply)
if defn.isFunctionType(qual.tpe.widen) =>
qual.tpe match
case ref: CaptureRef
if ref.isTracked && tree.args.forall(_.tpe.captureSet.isAlwaysEmpty) =>
tp.derivedCapturingType(tp1, refs ** ref.singletonCaptureSet)
.showing(i"narrow $tree: $tp --> $result", capt)
case _ => tp
case _ => tp
case tp => tp

override def recheck(tree: Tree, pt: Type = WildcardType)(using Context): Type =
val res = super.recheck(tree, pt)
Expand Down Expand Up @@ -319,6 +331,42 @@ class CheckCaptures extends Recheck:
case _ =>
super.recheckFinish(tpe, tree, pt)

/** This method implements the rule outlined in #14390:
* When checking an expression `e: Ca Ta` against an expected type `Cx Tx`
* where the capture set of `Cx` contains this and any method inside the class
* `Cls` of `this` that contains `e` has only pure parameters, drop from `Ca`
* all references to variables or this references outside `Cls`. These are all
* accessed through this, so are already accounted for by `Cx`.
*/
override def checkConformsExpr(original: Type, actual: Type, expected: Type, tree: Tree)(using Context): Unit =
def isPure(info: Type): Boolean = info match
case info: PolyType => isPure(info.resType)
case info: MethodType => info.paramInfos.forall(_.captureSet.isAlwaysEmpty) && isPure(info.resType)
case _ => true
def isPureContext(owner: Symbol, limit: Symbol): Boolean =
if owner == limit then true
else if !owner.exists then false
else isPure(owner.info) && isPureContext(owner.owner, limit)
val actual1 = (expected, actual.widen) match
case (CapturingType(ecore, erefs, _), actualw @ CapturingType(acore, arefs, _)) =>
val arefs1 = (arefs /: erefs.elems) { (arefs1, eref) =>
eref match
case eref: ThisType if isPureContext(ctx.owner, eref.cls) =>
arefs1.filter {
case aref1: TermRef => !eref.cls.isContainedIn(aref1.symbol.owner)
case aref1: ThisType => !eref.cls.isContainedIn(aref1.cls)
case _ => true
}
case _ =>
arefs1
}
if arefs1 eq arefs then actual
else actualw.derivedCapturingType(acore, arefs1)
.showing(i"healing $actual --> $result", capt)
case _ =>
actual
super.checkConformsExpr(original, actual1, expected, tree)

override def checkUnit(unit: CompilationUnit)(using Context): Unit =
Setup(preRecheckPhase, thisPhase, recheckDef)
.traverse(ctx.compilationUnit.tpdTree)
Expand Down
4 changes: 2 additions & 2 deletions tests/neg-custom-args/captures/lazylist.check
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,6 @@ longer explanation available when compiling with `-explain`
17 | def tail = xs() // error: cannot have an inferred type
| ^^^^^^^^^^^^^^^
| Non-local method tail cannot have an inferred result type
| {*} lazylists.LazyList[T]
| with non-empty capture set {*}.
| {LazyCons.this.xs} lazylists.LazyList[T]
| with non-empty capture set {LazyCons.this.xs}.
| The type needs to be declared explicitly.
10 changes: 5 additions & 5 deletions tests/neg-custom-args/captures/lazylists1.check
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylists1.scala:25:63 -----------------------------------
25 | def concat(other: {f} LazyList[A]): {this} LazyList[A] = ??? : ({xs, f} LazyList[A]) // error
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^
| Found: {xs, f} LazyList[A]
| Required: {Mapped.this, xs} LazyList[A]
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylists1.scala:25:66 -----------------------------------
25 | def concat(other: {f} LazyList[A]): {this, f} LazyList[A] = ??? : ({xs, f} LazyList[A]) // error
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^
| Found: {xs, f} LazyList[A]
| Required: {Mapped.this, f} LazyList[A]

longer explanation available when compiling with `-explain`
2 changes: 1 addition & 1 deletion tests/neg-custom-args/captures/lazylists1.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@ extension [A](xs: {*} LazyList[A])
def head: B = f(xs.head)
def tail: {this} LazyList[B] = xs.tail.map(f) // OK
def drop(n: Int): {this} LazyList[B] = ??? : ({xs, f} LazyList[B]) // OK
def concat(other: {f} LazyList[A]): {this} LazyList[A] = ??? : ({xs, f} LazyList[A]) // error
def concat(other: {f} LazyList[A]): {this, f} LazyList[A] = ??? : ({xs, f} LazyList[A]) // error
new Mapped

32 changes: 22 additions & 10 deletions tests/neg-custom-args/captures/lazylists2.check
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,29 @@ longer explanation available when compiling with `-explain`
32 | new Mapped

longer explanation available when compiling with `-explain`
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylists2.scala:41:48 -----------------------------------
41 | def tail: {this} LazyList[B] = xs.tail.map(f) // error
| ^^^^^^^^^^^^^^
| Found: {f} LazyList[B]
| Required: {xs} LazyList[B]
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylists2.scala:36:4 ------------------------------------
36 | final class Mapped extends LazyList[B]: // error
| ^
| Found: {f, xs} LazyList[B]
| Required: {xs} LazyList[B]
37 | this: ({xs} Mapped) =>
38 | def isEmpty = false
39 | def head: B = f(xs.head)
40 | def tail: {this} LazyList[B] = xs.tail.map(f)
41 | new Mapped

longer explanation available when compiling with `-explain`
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylists2.scala:59:48 -----------------------------------
59 | def tail: {this} LazyList[B] = xs.tail.map(f) // error
| ^^^^^^^^^^^^^^
| Found: {f} LazyList[B]
| Required: {Mapped.this} LazyList[B]
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylists2.scala:54:4 ------------------------------------
54 | class Mapped extends LazyList[B]: // error
| ^
| Found: {f, xs} LazyList[B]
| Required: LazyList[B]
55 | this: ({xs, f} Mapped) =>
56 | def isEmpty = false
57 | def head: B = f(xs.head)
58 | def tail: {this} LazyList[B] = xs.tail.map(f)
59 | class Mapped2 extends Mapped:
60 | this: Mapped =>
61 | new Mapped2

longer explanation available when compiling with `-explain`
8 changes: 4 additions & 4 deletions tests/neg-custom-args/captures/lazylists2.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ extension [A](xs: {*} LazyList[A])
new Mapped

def map3[B](f: A => B): {xs} LazyList[B] =
final class Mapped extends LazyList[B]:
final class Mapped extends LazyList[B]: // error
this: ({xs} Mapped) =>

def isEmpty = false
def head: B = f(xs.head)
def tail: {this} LazyList[B] = xs.tail.map(f) // error
def tail: {this} LazyList[B] = xs.tail.map(f)
new Mapped

def map4[B](f: A => B): {xs} LazyList[B] =
Expand All @@ -51,12 +51,12 @@ extension [A](xs: {*} LazyList[A])
new Mapped

def map5[B](f: A => B): LazyList[B] =
class Mapped extends LazyList[B]:
class Mapped extends LazyList[B]: // error
this: ({xs, f} Mapped) =>

def isEmpty = false
def head: B = f(xs.head)
def tail: {this} LazyList[B] = xs.tail.map(f) // error
def tail: {this} LazyList[B] = xs.tail.map(f)
class Mapped2 extends Mapped:
this: Mapped =>
new Mapped2
Expand Down
49 changes: 49 additions & 0 deletions tests/pos-custom-args/captures/lazylists-exceptions.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import language.experimental.saferExceptions
import annotation.unchecked.uncheckedVariance

trait LazyList[+A]:
this: {*} LazyList[A] =>

def isEmpty: Boolean
def head: A
def tail: {this} LazyList[A]

object LazyNil extends LazyList[Nothing]:
def isEmpty: Boolean = true
def head = ???
def tail = ???

final class LazyCons[+T](val x: T, val xs: () => {*} LazyList[T]) extends LazyList[T]:
this: {*} LazyList[T] =>

var forced = false
var cache: {this} LazyList[T @uncheckedVariance] = compiletime.uninitialized

private def force =
if !forced then
cache = xs()
forced = true
cache

def isEmpty = false
def head = x
def tail: {this} LazyList[T] = force

extension [A](xs: {*} LazyList[A])
def map[B](f: A => B): {xs, f} LazyList[B] =
if xs.isEmpty then LazyNil
else LazyCons(f(xs.head), () => xs.tail.map(f))

class Ex1 extends Exception
class Ex2 extends Exception

def test(using cap1: CanThrow[Ex1], cap2: CanThrow[Ex2]) =
val xs = LazyCons(1, () => LazyNil)

def f(x: Int): Int throws Ex1 =
if x < 0 then throw Ex1()
x * x

val res = xs.map(f)
res: {cap1} LazyList[Int]

1 change: 0 additions & 1 deletion tests/pos-custom-args/captures/lazylists.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ extension [A](xs: {*} LazyList[A])
def isEmpty = false
def head: B = f(xs.head)
def tail: {this} LazyList[B] = xs.tail.map(f) // OK
def concat(other: {f} LazyList[A]): {this, f} LazyList[A] = ??? : {xs, f} LazyList[A] // OK
if xs.isEmpty then LazyNil
else new Mapped

Expand Down
15 changes: 11 additions & 4 deletions tests/pos-custom-args/captures/lazylists1.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,36 @@ trait LazyList[+A]:
def isEmpty: Boolean
def head: A
def tail: {this} LazyList[A]
def concat[B >: A](other: {*} LazyList[B]): {this, other} LazyList[B]

object LazyNil extends LazyList[Nothing]:
def isEmpty: Boolean = true
def head = ???
def tail = ???
def concat[B](other: {*} LazyList[B]): {other} LazyList[B] = other

final class LazyCons[+T](val x: T, val xs: Int => {*} LazyList[T]) extends LazyList[T]:
final class LazyCons[+T](val x: T, val xs: () => {*} LazyList[T]) extends LazyList[T]:
this: {*} LazyList[T] =>

def isEmpty = false
def head = x
def tail: {this} LazyList[T] = xs(0)
def tail: {this} LazyList[T] = xs()
def concat[U >: T](other: {*} LazyList[U]): {this, other} LazyList[U] =
LazyCons(head, () => tail.concat(other))

extension [A](xs: {*} LazyList[A])
def map[B](f: A => B): {xs, f} LazyList[B] =
if xs.isEmpty then LazyNil
else LazyCons(f(xs.head), x => xs.tail.map(f))
else LazyCons(f(xs.head), () => xs.tail.map(f))

def test(cap1: Cap, cap2: Cap) =
def f(x: String): String = if cap1 == cap1 then "" else "a"
def g(x: String): String = if cap2 == cap2 then "" else "a"

val xs = new LazyCons("", x => if f("") == f("") then LazyNil else LazyNil)
val xs = new LazyCons("", () => if f("") == f("") then LazyNil else LazyNil)
val xsc: {cap1} LazyList[String] = xs
val ys = xs.map(g)
val ysc: {cap1, cap2} LazyList[String] = ys
val zs = new LazyCons("", () => if g("") == g("") then LazyNil else LazyNil)
val as = xs.concat(zs)
val asc: {xs, zs} LazyList[String] = as

0 comments on commit 9a3f60d

Please sign in to comment.