Skip to content

Commit

Permalink
Turn on separation checking for applications
Browse files Browse the repository at this point in the history
 - Use unsafeAssumeSeparate(...) as an escape hatch
  • Loading branch information
odersky committed Dec 21, 2024
1 parent 3a26fe8 commit 8cbc022
Show file tree
Hide file tree
Showing 24 changed files with 254 additions and 66 deletions.
20 changes: 18 additions & 2 deletions compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,12 @@ object CheckCaptures:

/** Was a new type installed for this tree? */
def hasNuType: Boolean

/** Is this tree passed to a parameter or assigned to a value with a type
* that contains cap in no-flip covariant position, which will necessite
* a separation check?
*/
def needsSepCheck: Boolean
end CheckerAPI

class CheckCaptures extends Recheck, SymTransformer:
Expand Down Expand Up @@ -279,6 +285,12 @@ class CheckCaptures extends Recheck, SymTransformer:
*/
private val todoAtPostCheck = new mutable.ListBuffer[() => Unit]

/** Trees that will need a separation check because they contain cap */
private val sepCheckable = util.EqHashSet[Tree]()

extension [T <: Tree](tree: T)
def needsSepCheck: Boolean = sepCheckable.contains(tree)

/** Instantiate capture set variables appearing contra-variantly to their
* upper approximation.
*/
Expand Down Expand Up @@ -636,11 +648,11 @@ class CheckCaptures extends Recheck, SymTransformer:
val meth = tree.fun.symbol
if meth == defn.Caps_unsafeAssumePure then
val arg :: Nil = tree.args: @unchecked
val argType0 = recheck(arg, pt.capturing(CaptureSet.universal))
val argType0 = recheck(arg, pt.stripCapturing.capturing(CaptureSet.universal))
val argType =
if argType0.captureSet.isAlwaysEmpty then argType0
else argType0.widen.stripCapturing
capt.println(i"rechecking $arg with $pt: $argType")
capt.println(i"rechecking unsafeAssumePure of $arg with $pt: $argType")
super.recheckFinish(argType, tree, pt)
else
val res = super.recheckApply(tree, pt)
Expand All @@ -660,6 +672,9 @@ class CheckCaptures extends Recheck, SymTransformer:
capt.println(i"charging deep capture set of $arg: ${argType} = ${argType.deepCaptureSet}")
markFree(argType.deepCaptureSet, arg.srcPos)
case _ =>
if formal.containsCap then
arg.updNuType(freshenedFormal)
sepCheckable += arg
argType

/** Map existential captures in result to `cap` and implement the following
Expand Down Expand Up @@ -1785,6 +1800,7 @@ class CheckCaptures extends Recheck, SymTransformer:
end checker

checker.traverse(unit)(using ctx.withOwner(defn.RootClass))
if ccConfig.useFresh then SepChecker(this).traverse(unit)
if !ctx.reporter.errorsReported then
// We dont report errors here if previous errors were reported, because other
// errors often result in bad applied types, but flagging these bad types gives
Expand Down
116 changes: 116 additions & 0 deletions compiler/src/dotty/tools/dotc/cc/SepCheck.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package dotty.tools
package dotc
package cc
import ast.tpd
import collection.mutable

import core.*
import Symbols.*, Types.*
import Contexts.*, Names.*, Flags.*, Symbols.*, Decorators.*
import CaptureSet.{Refs, emptySet}
import config.Printers.capt
import StdNames.nme

class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
import tpd.*
import checker.*

extension (cs: CaptureSet)
def footprint(using Context): CaptureSet =
def recur(elems: CaptureSet.Refs, newElems: List[CaptureRef]): CaptureSet.Refs = newElems match
case newElem :: newElems1 =>
val superElems = newElem.captureSetOfInfo.elems.filter: superElem =>
!superElem.isMaxCapability && !elems.contains(superElem)
recur(superElems ++ elems, superElems.toList ++ newElems1)
case Nil => elems
val elems: CaptureSet.Refs = cs.elems.filter(!_.isMaxCapability)
CaptureSet(recur(elems, elems.toList))

def overlapWith(other: CaptureSet)(using Context): CaptureSet.Refs =
val refs1 = cs.elems
val refs2 = other.elems
def common(refs1: CaptureSet.Refs, refs2: CaptureSet.Refs) =
refs1.filter: ref =>
ref.isExclusive && refs2.exists(_.stripReadOnly eq ref)
common(refs1, refs2) ++ common(refs2, refs1)

private def hidden(elem: CaptureRef)(using Context): CaptureSet.Refs = elem match
case Fresh.Cap(hcs) => hcs.elems.filter(!_.isRootCapability) ++ hidden(hcs)
case ReadOnlyCapability(ref) => hidden(ref).map(_.readOnly)
case _ => emptySet

private def hidden(cs: CaptureSet)(using Context): CaptureSet.Refs =
val seen: util.EqHashSet[CaptureRef] = new util.EqHashSet

def hiddenByElem(elem: CaptureRef): CaptureSet.Refs =
if seen.add(elem) then elem match
case Fresh.Cap(hcs) => hcs.elems.filter(!_.isRootCapability) ++ recur(hcs)
case ReadOnlyCapability(ref) => hiddenByElem(ref).map(_.readOnly)
case _ => emptySet
else emptySet

def recur(cs: CaptureSet): CaptureSet.Refs =
(emptySet /: cs.elems): (elems, elem) =>
elems ++ hiddenByElem(elem)

recur(cs)
end hidden

private def checkApply(fn: Tree, args: List[Tree])(using Context): Unit =
val fnCaptures = fn.nuType.deepCaptureSet

def captures(arg: Tree) =
val argType = arg.nuType
argType match
case AnnotatedType(formal1, ann) if ann.symbol == defn.UseAnnot =>
argType.deepCaptureSet
case _ =>
argType.captureSet

val argCaptures = args.map(captures)
capt.println(i"check separate $fn($args), fnCaptures = $fnCaptures, argCaptures = $argCaptures")
var footprint = argCaptures.foldLeft(fnCaptures.footprint): (fp, ac) =>
fp ++ ac.footprint
val paramNames = fn.nuType.widen match
case MethodType(pnames) => pnames
case _ => args.indices.map(nme.syntheticParamName(_))
for (arg, ac, pname) <- args.lazyZip(argCaptures).lazyZip(paramNames) do
if arg.needsSepCheck then
val hiddenInArg = CaptureSet(hidden(ac))
//println(i"check sep $arg / $footprint / $hiddenInArg")
val overlap = hiddenInArg.footprint.overlapWith(footprint)
if !overlap.isEmpty then
def whatStr = if overlap.size == 1 then "this capability" else "these capabilities"
def funStr =
if fn.symbol.exists then i"${fn.symbol}"
else "the function"
report.error(
em"""Separation failure: argument to capture-polymorphic parameter $pname: ${arg.nuType}
|captures ${CaptureSet(overlap)} and also passes $whatStr separately to $funStr""",
arg.srcPos)
footprint ++= hiddenInArg

private def traverseApply(tree: Tree, argss: List[List[Tree]])(using Context): Unit = tree match
case Apply(fn, args) => traverseApply(fn, args :: argss)
case TypeApply(fn, args) => traverseApply(fn, argss) // skip type arguments
case _ =>
if argss.nestedExists(_.needsSepCheck) then
checkApply(tree, argss.flatten)

def traverse(tree: Tree)(using Context): Unit =
tree match
case tree: GenericApply =>
if tree.symbol != defn.Caps_unsafeAssumeSeparate then
tree.tpe match
case _: MethodOrPoly =>
case _ => traverseApply(tree, Nil)
traverseChildren(tree)
case _ =>
traverseChildren(tree)
end SepChecker






1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1000,6 +1000,7 @@ class Definitions {
@tu lazy val Caps_Exists: ClassSymbol = requiredClass("scala.caps.Exists")
@tu lazy val CapsUnsafeModule: Symbol = requiredModule("scala.caps.unsafe")
@tu lazy val Caps_unsafeAssumePure: Symbol = CapsUnsafeModule.requiredMethod("unsafeAssumePure")
@tu lazy val Caps_unsafeAssumeSeparate: Symbol = CapsUnsafeModule.requiredMethod("unsafeAssumeSeparate")
@tu lazy val Caps_ContainsTrait: TypeSymbol = CapsModule.requiredType("Contains")
@tu lazy val Caps_containsImpl: TermSymbol = CapsModule.requiredMethod("containsImpl")
@tu lazy val Caps_Mutable: ClassSymbol = requiredClass("scala.caps.Mutable")
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4178,7 +4178,7 @@ object Types extends TypeUtils {
tl => params.map(p => tl.integrate(params, adaptParamInfo(p))),
tl => tl.integrate(params, resultType))

/** Adapt info of parameter symbol to be integhrated into corresponding MethodType
/** Adapt info of parameter symbol to be integrated into corresponding MethodType
* using the scheme described in `fromSymbols`.
*/
def adaptParamInfo(param: Symbol, pinfo: Type)(using Context): Type =
Expand Down
6 changes: 5 additions & 1 deletion compiler/src/dotty/tools/dotc/transform/Recheck.scala
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,11 @@ abstract class Recheck extends Phase, SymTransformer:
* from the current type.
*/
def setNuType(tpe: Type): Unit =
if nuTypes.lookup(tree) == null && (tpe ne tree.tpe) then nuTypes(tree) = tpe
if nuTypes.lookup(tree) == null then updNuType(tpe)

/** Set new type of the tree unconditionally. */
def updNuType(tpe: Type): Unit =
if tpe ne tree.tpe then nuTypes(tree) = tpe

/** The new type of the tree, or if none was installed, the original type */
def nuType(using Context): Type =
Expand Down
5 changes: 5 additions & 0 deletions library/src/scala/caps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,9 @@ import annotation.{experimental, compileTimeOnly, retainsCap}
*/
def unsafeAssumePure: T = x

/** A wrapper around code for which separation checks are suppressed.
*/
def unsafeAssumeSeparate[T](op: T): T = op

end unsafe
end caps
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ object IndexedSeqView {

@SerialVersionUID(3L)
class Concat[A](prefix: SomeIndexedSeqOps[A]^, suffix: SomeIndexedSeqOps[A]^)
extends SeqView.Concat[A](prefix, suffix) with IndexedSeqView[A]
extends SeqView.Concat[A](prefix, caps.unsafe.unsafeAssumeSeparate(suffix)) with IndexedSeqView[A]

@SerialVersionUID(3L)
class Take[A](underlying: SomeIndexedSeqOps[A]^, n: Int)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -682,7 +682,7 @@ final class LazyListIterable[+A] private(@untrackedCaptures lazyState: () => Laz
remaining -= 1
scout = scout.tail
}
dropRightState(scout)
caps.unsafe.unsafeAssumeSeparate(dropRightState(scout))
}
}

Expand Down Expand Up @@ -879,6 +879,7 @@ final class LazyListIterable[+A] private(@untrackedCaptures lazyState: () => Laz
if (!cursor.stateDefined) b.append(sep).append("<not computed>")
} else {
@inline def same(a: LazyListIterable[A]^, b: LazyListIterable[A]^): Boolean = (a eq b) || (a.state eq b.state)
// !!!CC with qualifiers, same should have cap.rd parameters
// Cycle.
// If we have a prefix of length P followed by a cycle of length C,
// the scout will be at position (P%C) in the cycle when the cursor
Expand All @@ -890,7 +891,7 @@ final class LazyListIterable[+A] private(@untrackedCaptures lazyState: () => Laz
// the start of the loop.
var runner = this
var k = 0
while (!same(runner, scout)) {
while (!caps.unsafe.unsafeAssumeSeparate(same(runner, scout))) {
runner = runner.tail
scout = scout.tail
k += 1
Expand All @@ -900,11 +901,11 @@ final class LazyListIterable[+A] private(@untrackedCaptures lazyState: () => Laz
// everything once. If cursor is already at beginning, we'd better
// advance one first unless runner didn't go anywhere (in which case
// we've already looped once).
if (same(cursor, scout) && (k > 0)) {
if (caps.unsafe.unsafeAssumeSeparate(same(cursor, scout)) && (k > 0)) {
appendCursorElement()
cursor = cursor.tail
}
while (!same(cursor, scout)) {
while (!caps.unsafe.unsafeAssumeSeparate(same(cursor, scout))) {
appendCursorElement()
cursor = cursor.tail
}
Expand Down Expand Up @@ -1052,7 +1053,9 @@ object LazyListIterable extends IterableFactory[LazyListIterable] {
val head = it.next()
rest = rest.tail
restRef = rest // restRef.elem = rest
sCons(head, newLL(stateFromIteratorConcatSuffix(it)(flatMapImpl(rest, f).state)))
sCons(head, newLL(
caps.unsafe.unsafeAssumeSeparate(
stateFromIteratorConcatSuffix(it)(flatMapImpl(rest, f).state))))
} else State.Empty
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ private[mutable] object CheckedIndexedSeqView {

@SerialVersionUID(3L)
class Concat[A](prefix: SomeIndexedSeqOps[A]^, suffix: SomeIndexedSeqOps[A]^)(protected val mutationCount: () => Int)
extends IndexedSeqView.Concat[A](prefix, suffix) with CheckedIndexedSeqView[A]
extends IndexedSeqView.Concat[A](prefix, caps.unsafe.unsafeAssumeSeparate(suffix)) with CheckedIndexedSeqView[A]

@SerialVersionUID(3L)
class Take[A](underlying: SomeIndexedSeqOps[A]^, n: Int)(protected val mutationCount: () => Int)
Expand Down
5 changes: 5 additions & 0 deletions tests/neg-custom-args/captures/cc-dep-param.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
-- Error: tests/neg-custom-args/captures/cc-dep-param.scala:8:6 --------------------------------------------------------
8 | foo(a, useA) // error: separation failure
| ^
| Separation failure: argument to capture-polymorphic parameter x$0: Foo[Int]^
| captures {a} and also passes this capability separately to method foo
8 changes: 8 additions & 0 deletions tests/neg-custom-args/captures/cc-dep-param.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import language.experimental.captureChecking

trait Foo[T]
def test(): Unit =
val a: Foo[Int]^ = ???
val useA: () ->{a} Unit = ???
def foo[X](x: Foo[X]^, op: () ->{x} Unit): Unit = ???
foo(a, useA) // error: separation failure
12 changes: 6 additions & 6 deletions tests/neg-custom-args/captures/cc-subst-param-exact.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,29 @@ trait Ref[T] { def set(x: T): T }
def test() = {

def swap[T](x: Ref[T]^)(y: Ref[T]^{x}): Unit = ???
def foo[T](x: Ref[T]^): Unit =
def foo[T](x: Ref[T]^{cap.rd}): Unit =
swap(x)(x)

def bar[T](x: () => Ref[T]^)(y: Ref[T]^{x}): Unit =
def bar[T](x: () => Ref[T]^{cap.rd})(y: Ref[T]^{x}): Unit =
swap(x())(y) // error

def baz[T](x: Ref[T]^)(y: Ref[T]^{x}): Unit =
def baz[T](x: Ref[T]^{cap.rd})(y: Ref[T]^{x}): Unit =
swap(x)(y)
}

trait IO
type Op = () -> Unit
def test2(c: IO^, f: Op^{c}) = {
def run(io: IO^)(op: Op^{io}): Unit = op()
run(c)(f)
run(c)(f) // error: separation failure

def bad(getIO: () => IO^, g: Op^{getIO}): Unit =
run(getIO())(g) // error
run(getIO())(g) // error // error: separation failure
}

def test3() = {
def run(io: IO^)(op: Op^{io}): Unit = ???
val myIO: IO^ = ???
val myOp: Op^{myIO} = ???
run(myIO)(myOp)
run(myIO)(myOp) // error: separation failure
}
9 changes: 9 additions & 0 deletions tests/neg-custom-args/captures/filevar-expanded.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
-- Error: tests/neg-custom-args/captures/filevar-expanded.scala:34:19 --------------------------------------------------
34 | withFile(io3): f => // error: separation failure
| ^
| Separation failure: argument to capture-polymorphic parameter x$1: (f: test2.File^{io3}) => Unit
| captures {io3} and also passes this capability separately to method withFile
35 | val o = Service(io3)
36 | o.file = f // this is a bit dubious. It's legal since we treat class refinements
37 | // as capture set variables that can be made to include refs coming from outside.
38 | o.log
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ object test2:
op(new File)

def test(io3: IO^) =
withFile(io3): f =>
withFile(io3): f => // error: separation failure
val o = Service(io3)
o.file = f // this is a bit dubious. It's legal since we treat class refinements
// as capture set variables that can be made to include refs coming from outside.
Expand Down
5 changes: 5 additions & 0 deletions tests/neg-custom-args/captures/function-combinators.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
-- Error: tests/neg-custom-args/captures/function-combinators.scala:15:22 ----------------------------------------------
15 | val b2 = g1.andThen(g1); // error: separation failure
| ^^
| Separation failure: argument to capture-polymorphic parameter x$0: Int => Int
| captures {ctx1} and also passes this capability separately to method andThen
30 changes: 30 additions & 0 deletions tests/neg-custom-args/captures/function-combinators.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
class ContextClass
type Context = ContextClass^
import caps.unsafe.unsafeAssumePure

def Test(using ctx1: Context, ctx2: Context) =
val f: Int => Int = identity
val g1: Int ->{ctx1} Int = identity
val g2: Int ->{ctx2} Int = identity
val h: Int -> Int = identity
val a1 = f.andThen(f); val _: Int ->{f} Int = a1
val a2 = f.andThen(g1); val _: Int ->{f, g1} Int = a2
val a3 = f.andThen(g2); val _: Int ->{f, g2} Int = a3
val a4 = f.andThen(h); val _: Int ->{f} Int = a4
val b1 = g1.andThen(f); val _: Int ->{f, g1} Int = b1
val b2 = g1.andThen(g1); // error: separation failure
val _: Int ->{g1} Int = b2
val b3 = g1.andThen(g2); val _: Int ->{g1, g2} Int = b3
val b4 = g1.andThen(h); val _: Int ->{g1} Int = b4
val c1 = h.andThen(f); val _: Int ->{f} Int = c1
val c2 = h.andThen(g1); val _: Int ->{g1} Int = c2
val c3 = h.andThen(g2); val _: Int ->{g2} Int = c3
val c4 = h.andThen(h); val _: Int -> Int = c4

val f2: (Int, Int) => Int = _ + _
val f2c = f2.curried; val _: Int -> Int ->{f2} Int = f2c
val f2t = f2.tupled; val _: ((Int, Int)) ->{f2} Int = f2t

val f3: (Int, Int, Int) => Int = ???
val f3c = f3.curried; val _: Int -> Int -> Int ->{f3} Int = f3c
val f3t = f3.tupled; val _: ((Int, Int, Int)) ->{f3} Int = f3t
Loading

0 comments on commit 8cbc022

Please sign in to comment.