Skip to content

Commit

Permalink
Merge pull request #14381 from dotty-staging/cc-root-captures
Browse files Browse the repository at this point in the history
New scheme to reject root captures
  • Loading branch information
odersky authored Feb 1, 2022
2 parents 6d672b1 + 18dd570 commit 90fa052
Show file tree
Hide file tree
Showing 12 changed files with 147 additions and 105 deletions.
26 changes: 25 additions & 1 deletion compiler/src/dotty/tools/dotc/cc/CaptureOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package dotc
package cc

import core.*
import Types.*, Symbols.*, Contexts.*, Annotations.*
import Types.*, Symbols.*, Contexts.*, Annotations.*, Flags.*
import ast.{tpd, untpd}
import Decorators.*, NameOps.*
import config.Printers.capt
Expand Down Expand Up @@ -85,3 +85,27 @@ extension (tp: Type)
isImpure = true).appliedTo(args)
case _ =>
tp

extension (sym: Symbol)

/** Does this symbol allow results carrying the universal capability?
* Currently this is true only for function type applies (since their
* results are unboxed) and `erasedValue` since this function is magic in
* that is allows to conjure global capabilies from nothing (aside: can we find a
* more controlled way to achieve this?).
* But it could be generalized to other functions that so that they can take capability
* classes as arguments.
*/
def allowsRootCapture(using Context): Boolean =
sym == defn.Compiletime_erasedValue
|| defn.isFunctionClass(sym.maybeOwner)

def unboxesResult(using Context): Boolean =
def containsEnclTypeParam(tp: Type): Boolean = tp.strippedDealias match
case tp @ TypeRef(pre: ThisType, _) => tp.symbol.is(Param)
case tp: TypeParamRef => true
case tp: AndOrType => containsEnclTypeParam(tp.tp1) || containsEnclTypeParam(tp.tp2)
case tp: RefinedType => containsEnclTypeParam(tp.parent) || containsEnclTypeParam(tp.refinedInfo)
case _ => false
containsEnclTypeParam(sym.info.finalResultType)
&& !sym.allowsRootCapture
10 changes: 10 additions & 0 deletions compiler/src/dotty/tools/dotc/cc/CaptureSet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,10 @@ sealed abstract class CaptureSet extends Showable:
def - (ref: CaptureRef)(using Context): CaptureSet =
this -- ref.singletonCaptureSet

def disallowRootCapability(handler: () => Unit)(using Context): this.type =
if isUniversal then handler()
this

def filter(p: CaptureRef => Boolean)(using Context): CaptureSet =
if this.isConst then
val elems1 = elems.filter(p)
Expand Down Expand Up @@ -276,6 +280,7 @@ object CaptureSet:
var deps: Deps = emptySet
def isConst = isSolved
def isAlwaysEmpty = false
var addRootHandler: () => Unit = () => ()

private def recordElemsState()(using VarState): Boolean =
varState.getElems(this) match
Expand All @@ -296,6 +301,7 @@ object CaptureSet:
def addNewElems(newElems: Refs, origin: CaptureSet)(using Context, VarState): CompareResult =
if !isConst && recordElemsState() then
elems ++= newElems
if isUniversal then addRootHandler()
// assert(id != 2 || elems.size != 2, this)
(CompareResult.OK /: deps) { (r, dep) =>
r.andAlso(dep.tryInclude(newElems, this))
Expand All @@ -312,6 +318,10 @@ object CaptureSet:
else
CompareResult.fail(this)

override def disallowRootCapability(handler: () => Unit)(using Context): this.type =
addRootHandler = handler
super.disallowRootCapability(handler)

private var computingApprox = false

final def upperApprox(origin: CaptureSet)(using Context): CaptureSet =
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/transform/Recheck.scala
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ abstract class Recheck extends Phase, IdentityDenotTransformer:
case tree: Alternative => recheckAlternative(tree, pt)
case tree: PackageDef => recheckPackageDef(tree)
case tree: Thicket => defn.NothingType
case tree: Import => defn.NothingType

tree match
case tree: NameTree => recheckNamed(tree, pt)
Expand Down
77 changes: 20 additions & 57 deletions compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,6 @@ object CheckCaptures:
if remaining.accountsFor(firstRef) then
report.warning(em"redundant capture: $remaining already accounts for $firstRef", ann.srcPos)

/** Does this function allow type arguments carrying the universal capability?
* Currently this is true only for `erasedValue` since this function is magic in
* that is allows to conjure global capabilies from nothing (aside: can we find a
* more controlled way to achieve this?).
* But it could be generalized to other functions that so that they can take capability
* classes as arguments.
*/
private def allowUniversalArguments(fn: Tree)(using Context): Boolean =
fn.symbol == defn.Compiletime_erasedValue

class CheckCaptures extends Recheck:
thisPhase =>

Expand Down Expand Up @@ -309,6 +299,26 @@ class CheckCaptures extends Recheck:
includeBoxedCaptures(res, tree.srcPos)
res

override def recheckFinish(tpe: Type, tree: Tree, pt: Type)(using Context): Type =
val typeToCheck = tree match
case _: Ident | _: Select | _: Apply | _: TypeApply if tree.symbol.unboxesResult =>
tpe
case _: Try =>
tpe
case ValDef(_, tpt, _) if tree.symbol.is(Mutable) =>
tree.symbol.info
case _ =>
NoType
if typeToCheck.exists then
typeToCheck.widenDealias match
case wtp @ CapturingType(parent, refs, _) =>
refs.disallowRootCapability { () =>
val kind = if tree.isInstanceOf[ValDef] then "mutable variable" else "expression"
report.error(em"the $kind's type $wtp is not allowed to capture the root capability `*`", tree.srcPos)
}
case _ =>
super.recheckFinish(tpe, tree, pt)

override def checkUnit(unit: CompilationUnit)(using Context): Unit =
Setup(preRecheckPhase, thisPhase, recheckDef)
.traverse(ctx.compilationUnit.tpdTree)
Expand All @@ -319,45 +329,6 @@ class CheckCaptures extends Recheck:
show(unit.tpdTree) // this dows not print tree, but makes its variables visible for dependency printing
}

def checkNotGlobal(tree: Tree, tp: Type, isVar: Boolean, allArgs: Tree*)(using Context): Unit =
for ref <- tp.captureSet.elems do
val isGlobal = ref match
case ref: TermRef => ref.isRootCapability
case _ => false
if isGlobal then
val what = if ref.isRootCapability then "universal" else "global"
val notAllowed = i" is not allowed to capture the $what capability $ref"
def msg =
if allArgs.isEmpty then
i"${if isVar then "type of mutable variable" else "result type"} ${tree.knownType}$notAllowed"
else tree match
case tree: InferredTypeTree =>
i"""inferred type argument ${tree.knownType}$notAllowed
|
|The inferred arguments are: [${allArgs.map(_.knownType)}%, %]"""
case _ => s"type argument$notAllowed"
report.error(msg, tree.srcPos)

def checkNotGlobal(tree: Tree, allArgs: Tree*)(using Context): Unit =
tree match
case LambdaTypeTree(_, restpt) =>
checkNotGlobal(restpt, allArgs*)
case _ =>
checkNotGlobal(tree, tree.knownType, isVar = false, allArgs*)

def checkNotGlobalDeep(tree: Tree)(using Context): Unit =
val checker = new TypeTraverser:
def traverse(tp: Type): Unit = tp match
case tp: TypeRef =>
tp.info match
case TypeBounds(_, hi) => traverse(hi)
case _ =>
case tp: TermRef =>
case _ =>
checkNotGlobal(tree, tp, isVar = true)
traverseChildren(tp)
checker.traverse(tree.knownType)

object PostCheck extends TreeTraverser:
def traverse(tree: Tree)(using Context) = trace{i"post check $tree"} {
tree match
Expand All @@ -370,10 +341,6 @@ class CheckCaptures extends Recheck:
checkWellformedPost(annot.tree)
case _ =>
}
case tree1 @ TypeApply(fn, args) if !allowUniversalArguments(fn) =>
for arg <- args do
//println(i"checking $arg in $tree: ${tree.knownType.captureSet}")
checkNotGlobal(arg, args*)
case t: ValOrDefDef if t.tpt.isInstanceOf[InferredTypeTree] =>
val sym = t.symbol
val isLocal =
Expand All @@ -396,10 +363,6 @@ class CheckCaptures extends Recheck:
|The type needs to be declared explicitly.""", t.srcPos)
case _ =>
inferred.foreachPart(checkPure, StopAt.Static)
case t: ValDef if t.symbol.is(Mutable) =>
checkNotGlobalDeep(t.tpt)
case t: Try =>
checkNotGlobal(t)
case _ =>
traverseChildren(tree)
}
Expand Down
4 changes: 2 additions & 2 deletions tests/neg-custom-args/captures/capt-test.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ def handle[E <: Exception, R <: Top](op: (CanThrow[E]) => R)(handler: E => R): R
catch case ex: E => handler(ex)

def test: Unit =
val b = handle[Exception, () => Nothing] { // error
val b = handle[Exception, () => Nothing] {
(x: CanThrow[Exception]) => () => raise(new Exception)(using x)
} {
} { // error
(ex: Exception) => ???
}
26 changes: 19 additions & 7 deletions tests/neg-custom-args/captures/real-try.check
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
-- Error: tests/neg-custom-args/captures/real-try.scala:10:2 -----------------------------------------------------------
10 | try // error
-- Error: tests/neg-custom-args/captures/real-try.scala:12:2 -----------------------------------------------------------
12 | try // error
| ^
| result type {*} () -> Unit is not allowed to capture the universal capability *.type
11 | () => foo(1)
12 | catch
13 | case _: Ex1 => ???
14 | case _: Ex2 => ???
| the expression's type {*} () -> Unit is not allowed to capture the root capability `*`
13 | () => foo(1)
14 | catch
15 | case _: Ex1 => ???
16 | case _: Ex2 => ???
-- Error: tests/neg-custom-args/captures/real-try.scala:18:2 -----------------------------------------------------------
18 | try // error
| ^
| the expression's type {*} () -> ? Cell[Unit] is not allowed to capture the root capability `*`
19 | () => Cell(foo(1))
20 | catch
21 | case _: Ex1 => ???
22 | case _: Ex2 => ???
-- Error: tests/neg-custom-args/captures/real-try.scala:30:4 -----------------------------------------------------------
30 | b.x // error
| ^^^
| the expression's type box {*} () -> Unit is not allowed to capture the root capability `*`
16 changes: 16 additions & 0 deletions tests/neg-custom-args/captures/real-try.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,25 @@ class Ex2 extends Exception("Ex2")
def foo(i: Int): (CanThrow[Ex1], CanThrow[Ex2]) ?-> Unit =
if i > 0 then throw new Ex1 else throw new Ex2

class Cell[+T](val x: T)

def test() =
try // error
() => foo(1)
catch
case _: Ex1 => ???
case _: Ex2 => ???

try // error
() => Cell(foo(1))
catch
case _: Ex1 => ???
case _: Ex2 => ???

val b = try // ok here, but error on use
Cell(() => foo(1))//: Cell[box {ev} () => Unit] <: Cell[box {*} () => Unit]
catch
case _: Ex1 => ???
case _: Ex2 => ???

b.x // error
46 changes: 30 additions & 16 deletions tests/neg-custom-args/captures/try.check
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
-- Error: tests/neg-custom-args/captures/try.scala:24:3 ----------------------------------------------------------------
22 | val a = handle[Exception, CanThrow[Exception]] {
23 | (x: CanThrow[Exception]) => x
24 | }{ // error
| ^
| the expression's type {*} CT[Exception] is not allowed to capture the root capability `*`
25 | (ex: Exception) => ???
26 | }
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/try.scala:28:43 ------------------------------------------
28 | val b = handle[Exception, () -> Nothing] { // error
| ^
Expand All @@ -7,19 +15,25 @@
30 | } {

longer explanation available when compiling with `-explain`
-- Error: tests/neg-custom-args/captures/try.scala:22:28 ---------------------------------------------------------------
22 | val a = handle[Exception, CanThrow[Exception]] { // error
| ^^^^^^^^^^^^^^^^^^^
| type argument is not allowed to capture the universal capability (* : Any)
-- Error: tests/neg-custom-args/captures/try.scala:34:11 ---------------------------------------------------------------
34 | val xx = handle { // error
| ^^^^^^
| inferred type argument {*} () -> Int is not allowed to capture the universal capability (* : Any)
|
| The inferred arguments are: [? Exception, {*} () -> Int]
-- Error: tests/neg-custom-args/captures/try.scala:46:13 ---------------------------------------------------------------
46 |val global = handle { // error
| ^^^^^^
| inferred type argument {*} () -> Int is not allowed to capture the universal capability (* : Any)
|
| The inferred arguments are: [? Exception, {*} () -> Int]
-- Error: tests/neg-custom-args/captures/try.scala:39:4 ----------------------------------------------------------------
34 | val xx = handle {
35 | (x: CanThrow[Exception]) =>
36 | () =>
37 | raise(new Exception)(using x)
38 | 22
39 | } { // error
| ^
| the expression's type {*} () -> Int is not allowed to capture the root capability `*`
40 | (ex: Exception) => () => 22
41 | }
-- Error: tests/neg-custom-args/captures/try.scala:51:2 ----------------------------------------------------------------
46 |val global = handle {
47 | (x: CanThrow[Exception]) =>
48 | () =>
49 | raise(new Exception)(using x)
50 | 22
51 |} { // error
| ^
| the expression's type {*} () -> Int is not allowed to capture the root capability `*`
52 | (ex: Exception) => () => 22
53 |}
12 changes: 6 additions & 6 deletions tests/neg-custom-args/captures/try.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ def handle[E <: Exception, R <: Top](op: CanThrow[E] => R)(handler: E => R): R =
catch case ex: E => handler(ex)

def test =
val a = handle[Exception, CanThrow[Exception]] { // error
val a = handle[Exception, CanThrow[Exception]] {
(x: CanThrow[Exception]) => x
}{
}{ // error
(ex: Exception) => ???
}

Expand All @@ -31,23 +31,23 @@ def test =
(ex: Exception) => ???
}

val xx = handle { // error
val xx = handle {
(x: CanThrow[Exception]) =>
() =>
raise(new Exception)(using x)
22
} {
} { // error
(ex: Exception) => () => 22
}
val yy = xx :: Nil
yy // OK


val global = handle { // error
val global = handle {
(x: CanThrow[Exception]) =>
() =>
raise(new Exception)(using x)
22
} {
} { // error
(ex: Exception) => () => 22
}
4 changes: 2 additions & 2 deletions tests/neg-custom-args/captures/try3.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ def raise[E <: Exception](ex: E)(using CanThrow[E]): Nothing =

@main def Test: Int =
def f(a: Boolean) =
handle { // error
handle {
if !a then raise(IOException())
(b: Boolean) =>
if !b then raise(IOException())
0
} {
} { // error
ex => (b: Boolean) => -1
}
val g = f(true)
Expand Down
Loading

0 comments on commit 90fa052

Please sign in to comment.