diff --git a/compiler/src/dotty/tools/dotc/ast/untpd.scala b/compiler/src/dotty/tools/dotc/ast/untpd.scala index c42e8f71246d..b7ad12369b88 100644 --- a/compiler/src/dotty/tools/dotc/ast/untpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/untpd.scala @@ -521,12 +521,17 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { def captureRoot(using Context): Select = Select(scalaDot(nme.caps), nme.CAPTURE_ROOT) - def captureRootIn(using Context): Select = - Select(scalaDot(nme.caps), nme.capIn) - def makeRetaining(parent: Tree, refs: List[Tree], annotName: TypeName)(using Context): Annotated = Annotated(parent, New(scalaAnnotationDot(annotName), List(refs))) + def makeCapsOf(id: Ident)(using Context): Tree = + TypeApply(Select(scalaDot(nme.caps), nme.capsOf), id :: Nil) + + def makeCapsBound()(using Context): Tree = + makeRetaining( + Select(scalaDot(nme.caps), tpnme.CapSet), + Nil, tpnme.retainsCap) + def makeConstructor(tparams: List[TypeDef], vparamss: List[List[ValDef]], rhs: Tree = EmptyTree)(using Context): DefDef = DefDef(nme.CONSTRUCTOR, joinParams(tparams, vparamss), TypeTree(), rhs) diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala index 4dda8f1803e0..09a56fb1b359 100644 --- a/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala +++ b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala @@ -131,6 +131,8 @@ extension (tree: Tree) def toCaptureRef(using Context): CaptureRef = tree match case ReachCapabilityApply(arg) => arg.toCaptureRef.reach + case CapsOfApply(arg) => + arg.toCaptureRef case _ => tree.tpe match case ref: CaptureRef if ref.isTrackableRef => ref @@ -145,7 +147,7 @@ extension (tree: Tree) case Some(refs) => refs case None => val refs = CaptureSet(tree.retainedElems.map(_.toCaptureRef)*) - .showing(i"toCaptureSet $tree --> $result", capt) + //.showing(i"toCaptureSet $tree --> $result", capt) tree.putAttachment(Captures, refs) refs @@ -526,6 +528,14 @@ object ReachCapabilityApply: case Apply(reach, arg :: Nil) if reach.symbol == defn.Caps_reachCapability => Some(arg) case _ => None +/** An extractor for `caps.capsOf[X]`, which is used to express a generic capture set + * as a tree in a @retains annotation. + */ +object CapsOfApply: + def unapply(tree: TypeApply)(using Context): Option[Tree] = tree match + case TypeApply(capsOf, arg :: Nil) if capsOf.symbol == defn.Caps_capsOf => Some(arg) + case _ => None + class AnnotatedCapability(annot: Context ?=> ClassSymbol): def apply(tp: Type)(using Context) = AnnotatedType(tp, Annotation(annot, util.Spans.NoSpan)) diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala b/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala index 98dc5db878e0..0b19f75f14d0 100644 --- a/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala +++ b/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala @@ -879,7 +879,9 @@ object CaptureSet: val r1 = tm(r) val upper = r1.captureSet def isExact = - upper.isAlwaysEmpty || upper.isConst && upper.elems.size == 1 && upper.elems.contains(r1) + upper.isAlwaysEmpty + || upper.isConst && upper.elems.size == 1 && upper.elems.contains(r1) + || r.derivesFrom(defn.Caps_CapSet) if variance > 0 || isExact then upper else if variance < 0 then CaptureSet.empty else upper.maybe diff --git a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala index 3981dcbb34a2..8eb2f2420369 100644 --- a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala +++ b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala @@ -123,16 +123,24 @@ object CheckCaptures: case _: SingletonType => report.error(em"Singleton type $parent cannot have capture set", parent.srcPos) case _ => + def check(elem: Tree, pos: SrcPos): Unit = elem.tpe match + case ref: CaptureRef => + if !ref.isTrackableRef then + report.error(em"$elem cannot be tracked since it is not a parameter or local value", pos) + case tpe => + report.error(em"$elem: $tpe is not a legal element of a capture set", pos) for elem <- ann.retainedElems do - val elem1 = elem match - case ReachCapabilityApply(arg) => arg - case _ => elem - elem1.tpe match - case ref: CaptureRef => - if !ref.isTrackableRef then - report.error(em"$elem cannot be tracked since it is not a parameter or local value", elem.srcPos) - case tpe => - report.error(em"$elem: $tpe is not a legal element of a capture set", elem.srcPos) + elem match + case CapsOfApply(arg) => + def isLegalCapsOfArg = + arg.symbol.isAbstractOrParamType && arg.symbol.info.derivesFrom(defn.Caps_CapSet) + if !isLegalCapsOfArg then + report.error( + em"""$arg is not a legal prefix for `^` here, + |is must be a type parameter or abstract type with a caps.CapSet upper bound.""", + elem.srcPos) + case ReachCapabilityApply(arg) => check(arg, elem.srcPos) + case _ => check(elem, elem.srcPos) /** Report an error if some part of `tp` contains the root capability in its capture set * or if it refers to an unsealed type parameter that could possibly be instantiated with diff --git a/compiler/src/dotty/tools/dotc/cc/Setup.scala b/compiler/src/dotty/tools/dotc/cc/Setup.scala index f588094fbdf3..6927983ad196 100644 --- a/compiler/src/dotty/tools/dotc/cc/Setup.scala +++ b/compiler/src/dotty/tools/dotc/cc/Setup.scala @@ -384,7 +384,7 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI: sym.updateInfo(thisPhase, info, newFlagsFor(sym)) toBeUpdated -= sym sym.namedType match - case ref: CaptureRef => ref.invalidateCaches() // TODO: needed? + case ref: CaptureRef if ref.isTrackableRef => ref.invalidateCaches() // TODO: needed? case _ => extension (sym: Symbol) def nextInfo(using Context): Type = diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index ad80d0565f63..e8a853469f6f 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -991,8 +991,10 @@ class Definitions { @tu lazy val CapsModule: Symbol = requiredModule("scala.caps") @tu lazy val captureRoot: TermSymbol = CapsModule.requiredValue("cap") - @tu lazy val Caps_Capability: ClassSymbol = requiredClass("scala.caps.Capability") + @tu lazy val Caps_Capability: TypeSymbol = CapsModule.requiredType("Capability") + @tu lazy val Caps_CapSet = requiredClass("scala.caps.CapSet") @tu lazy val Caps_reachCapability: TermSymbol = CapsModule.requiredMethod("reachCapability") + @tu lazy val Caps_capsOf: TermSymbol = CapsModule.requiredMethod("capsOf") @tu lazy val Caps_Exists = requiredClass("scala.caps.Exists") @tu lazy val CapsUnsafeModule: Symbol = requiredModule("scala.caps.unsafe") @tu lazy val Caps_unsafeAssumePure: Symbol = CapsUnsafeModule.requiredMethod("unsafeAssumePure") diff --git a/compiler/src/dotty/tools/dotc/core/StdNames.scala b/compiler/src/dotty/tools/dotc/core/StdNames.scala index 6548b46186bb..d3e198a7e7a7 100644 --- a/compiler/src/dotty/tools/dotc/core/StdNames.scala +++ b/compiler/src/dotty/tools/dotc/core/StdNames.scala @@ -358,6 +358,7 @@ object StdNames { val AppliedTypeTree: N = "AppliedTypeTree" val ArrayAnnotArg: N = "ArrayAnnotArg" val CAP: N = "CAP" + val CapSet: N = "CapSet" val Constant: N = "Constant" val ConstantType: N = "ConstantType" val Eql: N = "Eql" @@ -441,8 +442,8 @@ object StdNames { val bytes: N = "bytes" val canEqual_ : N = "canEqual" val canEqualAny : N = "canEqualAny" - val capIn: N = "capIn" val caps: N = "caps" + val capsOf: N = "capsOf" val captureChecking: N = "captureChecking" val checkInitialized: N = "checkInitialized" val classOf: N = "classOf" diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 958515a9e625..7088232c26e1 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -2839,7 +2839,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling private def existentialVarsConform(tp1: Type, tp2: Type) = tp2 match case tp2: TermParamRef => tp1 match - case tp1: CaptureRef => subsumesExistentially(tp2, tp1) + case tp1: CaptureRef if tp1.isTrackableRef => subsumesExistentially(tp2, tp1) case _ => false case _ => false diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index 7bcb7e453647..57618df2ebcd 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -2313,7 +2313,11 @@ object Types extends TypeUtils { override def captureSet(using Context): CaptureSet = val cs = captureSetOfInfo - if isTrackableRef && !cs.isAlwaysEmpty then singletonCaptureSet else cs + if isTrackableRef then + if cs.isAlwaysEmpty then cs else singletonCaptureSet + else dealias match + case _: (TypeRef | TypeParamRef) => CaptureSet.empty + case _ => cs end CaptureRef @@ -3032,7 +3036,7 @@ object Types extends TypeUtils { abstract case class TypeRef(override val prefix: Type, private var myDesignator: Designator) - extends NamedType { + extends NamedType, CaptureRef { type ThisType = TypeRef type ThisName = TypeName @@ -3081,6 +3085,9 @@ object Types extends TypeUtils { /** Hook that can be called from creation methods in TermRef and TypeRef */ def validated(using Context): this.type = this + + override def isTrackableRef(using Context) = + symbol.isAbstractOrParamType && derivesFrom(defn.Caps_CapSet) } final class CachedTermRef(prefix: Type, designator: Designator, hc: Int) extends TermRef(prefix, designator) { @@ -4841,7 +4848,8 @@ object Types extends TypeUtils { /** Only created in `binder.paramRefs`. Use `binder.paramRefs(paramNum)` to * refer to `TypeParamRef(binder, paramNum)`. */ - abstract case class TypeParamRef(binder: TypeLambda, paramNum: Int) extends ParamRef { + abstract case class TypeParamRef(binder: TypeLambda, paramNum: Int) + extends ParamRef, CaptureRef { type BT = TypeLambda def kindString: String = "Type" def copyBoundType(bt: BT): Type = bt.paramRefs(paramNum) @@ -4861,6 +4869,8 @@ object Types extends TypeUtils { case bound: OrType => occursIn(bound.tp1, fromBelow) || occursIn(bound.tp2, fromBelow) case _ => false } + + override def isTrackableRef(using Context) = derivesFrom(defn.Caps_CapSet) } private final class TypeParamRefImpl(binder: TypeLambda, paramNum: Int) extends TypeParamRef(binder, paramNum) diff --git a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala index 4c13934f3473..5c0374ceba87 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala @@ -1541,7 +1541,7 @@ object Parsers { case _ => None } - /** CaptureRef ::= ident | `this` | `cap` [`[` ident `]`] + /** CaptureRef ::= ident [`*` | `^`] | `this` */ def captureRef(): Tree = if in.token == THIS then simpleRef() @@ -1551,6 +1551,10 @@ object Parsers { in.nextToken() atSpan(startOffset(id)): PostfixOp(id, Ident(nme.CC_REACH)) + else if isIdent(nme.UPARROW) then + in.nextToken() + atSpan(startOffset(id)): + makeCapsOf(cpy.Ident(id)(id.name.toTypeName)) else id /** CaptureSet ::= `{` CaptureRef {`,` CaptureRef} `}` -- under captureChecking @@ -1968,7 +1972,7 @@ object Parsers { } /** SimpleType ::= SimpleLiteral - * | ‘?’ SubtypeBounds + * | ‘?’ TypeBounds * | SimpleType1 * | SimpleType ‘(’ Singletons ‘)’ -- under language.experimental.dependent, checked in Typer * Singletons ::= Singleton {‘,’ Singleton} @@ -2188,9 +2192,15 @@ object Parsers { inBraces(refineStatSeq()) /** TypeBounds ::= [`>:' Type] [`<:' Type] + * | `^` -- under captureChecking */ def typeBounds(): TypeBoundsTree = - atSpan(in.offset) { TypeBoundsTree(bound(SUPERTYPE), bound(SUBTYPE)) } + atSpan(in.offset): + if in.isIdent(nme.UPARROW) && Feature.ccEnabled then + in.nextToken() + TypeBoundsTree(EmptyTree, makeCapsBound()) + else + TypeBoundsTree(bound(SUPERTYPE), bound(SUBTYPE)) private def bound(tok: Int): Tree = if (in.token == tok) { in.nextToken(); toplevelTyp() } @@ -3384,7 +3394,6 @@ object Parsers { * * DefTypeParamClause::= ‘[’ DefTypeParam {‘,’ DefTypeParam} ‘]’ * DefTypeParam ::= {Annotation} - * [`sealed`] -- under captureChecking * id [HkTypeParamClause] TypeParamBounds * * TypTypeParamClause::= ‘[’ TypTypeParam {‘,’ TypTypeParam} ‘]’ diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index d3e411d5ea4d..6ca87127d3c9 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -2328,7 +2328,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer val res = Throw(expr1).withSpan(tree.span) if Feature.ccEnabled && !cap.isEmpty && !ctx.isAfterTyper then // Record access to the CanThrow capabulity recovered in `cap` by wrapping - // the type of the `throw` (i.e. Nothing) in a `@requiresCapability` annotatoon. + // the type of the `throw` (i.e. Nothing) in a `@requiresCapability` annotation. Typed(res, TypeTree( AnnotatedType(res.tpe, diff --git a/library/src/scala/caps.scala b/library/src/scala/caps.scala index 5ae5b860f501..967246041082 100644 --- a/library/src/scala/caps.scala +++ b/library/src/scala/caps.scala @@ -1,6 +1,6 @@ package scala -import annotation.experimental +import annotation.{experimental, compileTimeOnly} @experimental object caps: @@ -16,6 +16,12 @@ import annotation.experimental @deprecated("Use `Capability` instead") type Cap = Capability + /** Carrier trait for capture set type parameters */ + trait CapSet extends Any + + @compileTimeOnly("Should be be used only internally by the Scala compiler") + def capsOf[CS]: Any = ??? + /** Reach capabilities x* which appear as terms in @retains annotations are encoded * as `caps.reachCapability(x)`. When converted to CaptureRef types in capture sets * they are represented as `x.type @annotation.internal.reachCapability`. diff --git a/tests/pos/cc-poly-1.scala b/tests/pos/cc-poly-1.scala new file mode 100644 index 000000000000..69b7557b8466 --- /dev/null +++ b/tests/pos/cc-poly-1.scala @@ -0,0 +1,23 @@ +import language.experimental.captureChecking +import annotation.experimental +import caps.{CapSet, Capability} + +@experimental object Test: + + class C extends Capability + class D + + def f[X^](x: D^{X^}): D^{X^} = x + def g[X^](x: D^{X^}, y: D^{X^}): D^{X^} = x + + def test(c1: C, c2: C) = + val d: D^{c1, c2} = D() + val x = f[CapSet^{c1, c2}](d) + val _: D^{c1, c2} = x + val d1: D^{c1} = D() + val d2: D^{c2} = D() + val y = g(d1, d2) + val _: D^{d1, d2} = y + val _: D^{c1, c2} = y + + diff --git a/tests/pos/cc-poly-source.scala b/tests/pos/cc-poly-source.scala new file mode 100644 index 000000000000..939f1f682dc8 --- /dev/null +++ b/tests/pos/cc-poly-source.scala @@ -0,0 +1,36 @@ +import language.experimental.captureChecking +import annotation.experimental +import caps.{CapSet, Capability} + +@experimental object Test: + + class Label //extends Capability + + class Listener + + class Source[X^]: + private var listeners: Set[Listener^{X^}] = Set.empty + def register(x: Listener^{X^}): Unit = + listeners += x + + def allListeners: Set[Listener^{X^}] = listeners + + def test1(lbl1: Label^, lbl2: Label^) = + val src = Source[CapSet^{lbl1, lbl2}] + def l1: Listener^{lbl1} = ??? + val l2: Listener^{lbl2} = ??? + src.register{l1} + src.register{l2} + val ls = src.allListeners + val _: Set[Listener^{lbl1, lbl2}] = ls + + def test2(lbls: List[Label^]) = + def makeListener(lbl: Label^): Listener^{lbl} = ??? + val listeners = lbls.map(makeListener) + val src = Source[CapSet^{lbls*}] + for l <- listeners do + src.register(l) + val ls = src.allListeners + val _: Set[Listener^{lbls*}] = ls + +