Skip to content

Commit

Permalink
Improve Contains handling (#21361)
Browse files Browse the repository at this point in the history
Make use of enclosing Contains assumptions to improve the subsumes
logic.
  • Loading branch information
odersky authored Aug 12, 2024
2 parents d40da0b + b2292a8 commit b825f5a
Show file tree
Hide file tree
Showing 8 changed files with 126 additions and 23 deletions.
18 changes: 18 additions & 0 deletions compiler/src/dotty/tools/dotc/cc/CaptureOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -713,3 +713,21 @@ extension (self: Type)
case _ =>
self

/** An extractor for a contains argument */
object ContainsImpl:
def unapply(tree: TypeApply)(using Context): Option[(Tree, Tree)] =
tree.fun.tpe.widen match
case fntpe: PolyType if tree.fun.symbol == defn.Caps_containsImpl =>
tree.args match
case csArg :: refArg :: Nil => Some((csArg, refArg))
case _ => None
case _ => None

/** An extractor for a contains parameter */
object ContainsParam:
def unapply(sym: Symbol)(using Context): Option[(TypeRef, CaptureRef)] =
sym.info.dealias match
case AppliedType(tycon, (cs: TypeRef) :: (ref: CaptureRef) :: Nil)
if tycon.typeSymbol == defn.Caps_ContainsTrait
&& cs.typeSymbol.isAbstractOrParamType => Some((cs, ref))
case _ => None
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/cc/CaptureRef.scala
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,12 @@ trait CaptureRef extends TypeProxy, ValueType:
case x1: SingletonCaptureRef => x1.subsumes(y)
case _ => false
case x: TermParamRef => subsumesExistentially(x, y)
case x: TypeRef => assumedContainsOf(x).contains(y)
case _ => false

def assumedContainsOf(x: TypeRef)(using Context): SimpleIdentitySet[CaptureRef] =
CaptureSet.assumedContains.getOrElse(x, SimpleIdentitySet.empty)

end CaptureRef

trait SingletonCaptureRef extends SingletonType, CaptureRef
Expand Down
8 changes: 7 additions & 1 deletion compiler/src/dotty/tools/dotc/cc/CaptureSet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import util.{SimpleIdentitySet, Property}
import typer.ErrorReporting.Addenda
import TypeComparer.subsumesExistentially
import util.common.alwaysTrue
import scala.collection.mutable
import scala.collection.{mutable, immutable}
import CCState.*

/** A class for capture sets. Capture sets can be constants or variables.
Expand Down Expand Up @@ -1125,6 +1125,12 @@ object CaptureSet:
foldOver(cs, t)
collect(CaptureSet.empty, tp)

type AssumedContains = immutable.Map[TypeRef, SimpleIdentitySet[CaptureRef]]
val AssumedContains: Property.Key[AssumedContains] = Property.Key()

def assumedContains(using Context): AssumedContains =
ctx.property(AssumedContains).getOrElse(immutable.Map.empty)

private val ShownVars: Property.Key[mutable.Set[Var]] = Property.Key()

/** Perform `op`. Under -Ycc-debug, collect and print info about all variables reachable
Expand Down
46 changes: 26 additions & 20 deletions compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -676,29 +676,24 @@ class CheckCaptures extends Recheck, SymTransformer:
i"Sealed type variable $pname", "be instantiated to",
i"This is often caused by a local capability$where\nleaking as part of its result.",
tree.srcPos)
val res = handleCall(meth, tree, () => Existential.toCap(super.recheckTypeApply(tree, pt)))
if meth == defn.Caps_containsImpl then checkContains(tree)
res
try handleCall(meth, tree, () => Existential.toCap(super.recheckTypeApply(tree, pt)))
finally checkContains(tree)
end recheckTypeApply

/** Faced with a tree of form `caps.contansImpl[CS, r.type]`, check that `R` is a tracked
* capability and assert that `{r} <:CS`.
*/
def checkContains(tree: TypeApply)(using Context): Unit =
tree.fun.knownType.widen match
case fntpe: PolyType =>
tree.args match
case csArg :: refArg :: Nil =>
val cs = csArg.knownType.captureSet
val ref = refArg.knownType
capt.println(i"check contains $cs , $ref")
ref match
case ref: CaptureRef if ref.isTracked =>
checkElem(ref, cs, tree.srcPos)
case _ =>
report.error(em"$refArg is not a tracked capability", refArg.srcPos)
case _ =>
case _ =>
def checkContains(tree: TypeApply)(using Context): Unit = tree match
case ContainsImpl(csArg, refArg) =>
val cs = csArg.knownType.captureSet
val ref = refArg.knownType
capt.println(i"check contains $cs , $ref")
ref match
case ref: CaptureRef if ref.isTracked =>
checkElem(ref, cs, tree.srcPos)
case _ =>
report.error(em"$refArg is not a tracked capability", refArg.srcPos)
case _ =>

override def recheckBlock(tree: Block, pt: Type)(using Context): Type =
inNestedLevel(super.recheckBlock(tree, pt))
Expand Down Expand Up @@ -814,15 +809,26 @@ class CheckCaptures extends Recheck, SymTransformer:
val localSet = capturedVars(sym)
if !localSet.isAlwaysEmpty then
curEnv = Env(sym, EnvKind.Regular, localSet, curEnv)

// ctx with AssumedContains entries for each Contains parameter
val bodyCtx =
var ac = CaptureSet.assumedContains
for paramSyms <- sym.paramSymss do
for case ContainsParam(cs, ref) <- paramSyms do
ac = ac.updated(cs, ac.getOrElse(cs, SimpleIdentitySet.empty) + ref)
if ac.isEmpty then ctx
else ctx.withProperty(CaptureSet.AssumedContains, Some(ac))

inNestedLevel: // TODO: needed here?
try checkInferredResult(super.recheckDefDef(tree, sym), tree)
try checkInferredResult(super.recheckDefDef(tree, sym)(using bodyCtx), tree)
finally
if !sym.isAnonymousFunction then
// Anonymous functions propagate their type to the enclosing environment
// so it is not in general sound to interpolate their types.
interpolateVarsIn(tree.tpt)
curEnv = saved

end recheckDefDef

/** If val or def definition with inferred (result) type is visible
* in other compilation units, check that the actual inferred type
* conforms to the expected type where all inferred capture sets are dropped.
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1002,7 +1002,7 @@ class Definitions {
@tu lazy val Caps_unsafeBox: Symbol = CapsUnsafeModule.requiredMethod("unsafeBox")
@tu lazy val Caps_unsafeUnbox: Symbol = CapsUnsafeModule.requiredMethod("unsafeUnbox")
@tu lazy val Caps_unsafeBoxFunArg: Symbol = CapsUnsafeModule.requiredMethod("unsafeBoxFunArg")
@tu lazy val Caps_ContainsTrait: TypeSymbol = CapsModule.requiredType("Capability")
@tu lazy val Caps_ContainsTrait: TypeSymbol = CapsModule.requiredType("Contains")
@tu lazy val Caps_containsImpl: TermSymbol = CapsModule.requiredMethod("containsImpl")

@tu lazy val PureClass: Symbol = requiredClass("scala.Pure")
Expand Down
11 changes: 10 additions & 1 deletion tests/pos-custom-args/captures/i21313.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
import caps.CapSet

trait Async:
def await[T, Cap^](using caps.Contains[Cap, this.type])(src: Source[T, Cap]^): T
def await[T, Cap^](using caps.Contains[Cap, this.type])(src: Source[T, Cap]^): T =
val x: Async^{this} = ???
val y: Async^{Cap^} = x
val ac: Async^ = ???
def f(using caps.Contains[Cap, ac.type]) =
val x2: Async^{this} = ???
val y2: Async^{Cap^} = x2
val x3: Async^{ac} = ???
val y3: Async^{Cap^} = x3
???

trait Source[+T, Cap^]:
final def await(using ac: Async^{Cap^}) = ac.await[T, Cap](this) // Contains[Cap, ac] is assured because {ac} <: Cap.
Expand Down
8 changes: 8 additions & 0 deletions tests/run/Providers.check
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,11 @@ Executing query: insert into subscribers(name, email) values Daniel daniel@Rockt
You've just been subscribed to RockTheJVM. Welcome, Martin
Acquired connection
Executing query: insert into subscribers(name, email) values Martin [email protected]

Injected2
You've just been subscribed to RockTheJVM. Welcome, Daniel
Acquired connection
Executing query: insert into subscribers(name, email) values Daniel [email protected]
You've just been subscribed to RockTheJVM. Welcome, Martin
Acquired connection
Executing query: insert into subscribers(name, email) values Martin [email protected]
52 changes: 52 additions & 0 deletions tests/run/Providers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ end Providers
Explicit().test()
println(s"\nInjected")
Injected().test()
println(s"\nInjected2")
Injected2().test()

/** Demonstrator for explicit dependency construction */
class Explicit:
Expand Down Expand Up @@ -173,5 +175,55 @@ class Injected:
end explicit
end Injected

/** Injected with builders in companion objects */
class Injected2:
import Providers.*

case class User(name: String, email: String)

class UserSubscription(emailService: EmailService, db: UserDatabase):
def subscribe(user: User) =
emailService.email(user)
db.insert(user)
object UserSubscription:
def apply()(using Provider[(EmailService, UserDatabase)]): UserSubscription =
new UserSubscription(provided[EmailService], provided[UserDatabase])

class EmailService:
def email(user: User) =
println(s"You've just been subscribed to RockTheJVM. Welcome, ${user.name}")

class UserDatabase(pool: ConnectionPool):
def insert(user: User) =
pool.get().runQuery(s"insert into subscribers(name, email) values ${user.name} ${user.email}")
object UserDatabase:
def apply()(using Provider[(ConnectionPool)]): UserDatabase =
new UserDatabase(provided[ConnectionPool])

class ConnectionPool(n: Int):
def get(): Connection =
println(s"Acquired connection")
Connection()

class Connection():
def runQuery(query: String): Unit =
println(s"Executing query: $query")

def test() =
given Provider[EmailService] = provide(EmailService())
given Provider[ConnectionPool] = provide(ConnectionPool(10))
given Provider[UserDatabase] = provide(UserDatabase())
given Provider[UserSubscription] = provide(UserSubscription())

def subscribe(user: User)(using Provider[UserSubscription]) =
val sub = UserSubscription()
sub.subscribe(user)

subscribe(User("Daniel", "[email protected]"))
subscribe(User("Martin", "[email protected]"))
end test
end Injected2




0 comments on commit b825f5a

Please sign in to comment.