diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala b/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala index a987b8788dd1..bf1306024e5f 100644 --- a/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala +++ b/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala @@ -104,7 +104,7 @@ sealed abstract class CaptureSet extends Showable: extension (x: CaptureRef) private def subsumes(y: CaptureRef) = (x eq y) || y.match - case y: TermRef => y.prefix eq x // ^^^ y.prefix.subsumes(x) ? + case y: TermRef => y.prefix eq x case _ => false /** {x} <:< this where <:< is subcapturing, but treating all variables diff --git a/compiler/src/dotty/tools/dotc/transform/Recheck.scala b/compiler/src/dotty/tools/dotc/transform/Recheck.scala index f9af3c38de8a..d100bdd36f21 100644 --- a/compiler/src/dotty/tools/dotc/transform/Recheck.scala +++ b/compiler/src/dotty/tools/dotc/transform/Recheck.scala @@ -126,15 +126,12 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: bindType def recheckValDef(tree: ValDef, sym: Symbol)(using Context): Unit = - if !tree.rhs.isEmpty then recheckRHS(tree.rhs, sym.info, sym) + if !tree.rhs.isEmpty then recheck(tree.rhs, sym.info) def recheckDefDef(tree: DefDef, sym: Symbol)(using Context): Unit = val rhsCtx = linkConstructorParams(sym).withOwner(sym) if !tree.rhs.isEmpty && !sym.isInlineMethod && !sym.isEffectivelyErased then - inContext(rhsCtx) { recheckRHS(tree.rhs, recheck(tree.tpt), sym) } - - def recheckRHS(tree: Tree, pt: Type, sym: Symbol)(using Context): Type = - recheck(tree, pt) + inContext(rhsCtx) { recheck(tree.rhs, recheck(tree.tpt)) } def recheckTypeDef(tree: TypeDef, sym: Symbol)(using Context): Type = recheck(tree.rhs) @@ -358,21 +355,22 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: // Don't report closure nodes, since their span is a point; wait instead // for enclosing block to preduce an error case _ => - val actual = tpe.widenExpr - val expected = pt.widenExpr - //println(i"check conforms $actual <:< $expected") - val isCompatible = - actual <:< expected - || expected.isRepeatedParam - && actual <:< expected.translateFromRepeated(toArray = tree.tpe.isRef(defn.ArrayClass)) - if !isCompatible then - recheckr.println(i"conforms failed for ${tree}: $tpe vs $expected") - err.typeMismatch(tree.withType(tpe), expected) - else if debugSuccesses then - tree match - case _: Ident => - println(i"SUCCESS $tree:\n${TypeComparer.explained(_.isSubType(actual, expected))}") - case _ => + checkConformsExpr(tpe, tpe.widenExpr, pt.widenExpr, tree) + + def checkConformsExpr(original: Type, actual: Type, expected: Type, tree: Tree)(using Context): Unit = + //println(i"check conforms $actual <:< $expected") + val isCompatible = + actual <:< expected + || expected.isRepeatedParam + && actual <:< expected.translateFromRepeated(toArray = tree.tpe.isRef(defn.ArrayClass)) + if !isCompatible then + recheckr.println(i"conforms failed for ${tree}: $original vs $expected") + err.typeMismatch(tree.withType(original), expected) + else if debugSuccesses then + tree match + case _: Ident => + println(i"SUCCESS $tree:\n${TypeComparer.explained(_.isSubType(actual, expected))}") + case _ => def checkUnit(unit: CompilationUnit)(using Context): Unit = recheck(unit.tpdTree) diff --git a/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala index 994df9a382c4..2e81576eb317 100644 --- a/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala +++ b/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala @@ -16,6 +16,7 @@ import transform.Recheck import Recheck.* import scala.collection.mutable import CaptureSet.withCaptureSetsExplained +import StdNames.nme import reporting.trace object CheckCaptures: @@ -213,22 +214,6 @@ class CheckCaptures extends Recheck: interpolateVarsIn(tree.tpt) curEnv = saved - override def recheckRHS(tree: Tree, pt: Type, sym: Symbol)(using Context): Type = - val pt1 = pt match - case CapturingType(core, refs, _) - if sym.owner.isClass && !sym.owner.isExtensibleClass - && refs.elems.contains(sym.owner.thisType) => - val paramCaptures = - sym.paramSymss.flatten.foldLeft(CaptureSet.empty) { (cs, p) => - val pcs = p.info.captureSet - (cs ++ (if pcs.isConst then pcs else CaptureSet.universal)).asConst - } - val declaredCaptures = sym.owner.asClass.givenSelfType.captureSet - pt.derivedCapturingType(core, refs ++ (declaredCaptures -- paramCaptures)) - case _ => - pt - recheck(tree, pt1) - override def recheckClassDef(tree: TypeDef, impl: Template, cls: ClassSymbol)(using Context): Type = for param <- cls.paramGetters do if param.is(Private) && !param.info.captureSet.isAlwaysEmpty then @@ -237,6 +222,8 @@ class CheckCaptures extends Recheck: param.srcPos) val saved = curEnv val localSet = capturedVars(cls) + for parent <- impl.parents do + checkSubset(capturedVars(parent.tpe.classSymbol), localSet, parent.srcPos) if !localSet.isAlwaysEmpty then curEnv = Env(cls, localSet, false, curEnv) try super.recheckClassDef(tree, impl, cls) finally curEnv = saved @@ -289,9 +276,34 @@ class CheckCaptures extends Recheck: finally curEnv = curEnv.outer recheckFinish(result, arg, pt) + /** A specialized implementation of the apply rule from https://github.com/lampepfl/dotty/discussions/14387: + * + * E |- f: Cf (Ra -> Cr Rr) + * E |- a: Ra + * ------------------------ + * E |- f a: Cr /\ {f} Rr + * + * Specialized for the case where `f` is a tracked and the arguments are pure. + * This replaces the previous rule #13657 while still allowing the code in pos/lazylists1.scala. + * We could consider generalizing to the case where the function arguments have non-empty + * capture sets as suggested in #14387, but that would make capture set computations more complex, + * so we should also evaluate the performance impact. + */ override def recheckApply(tree: Apply, pt: Type)(using Context): Type = includeCallCaptures(tree.symbol, tree.srcPos) - super.recheckApply(tree, pt) + super.recheckApply(tree, pt) match + case tp @ CapturingType(tp1, refs, kind) => + tree.fun match + case Select(qual, nme.apply) + if defn.isFunctionType(qual.tpe.widen) => + qual.tpe match + case ref: CaptureRef + if ref.isTracked && tree.args.forall(_.tpe.captureSet.isAlwaysEmpty) => + tp.derivedCapturingType(tp1, refs ** ref.singletonCaptureSet) + .showing(i"narrow $tree: $tp --> $result", capt) + case _ => tp + case _ => tp + case tp => tp override def recheck(tree: Tree, pt: Type = WildcardType)(using Context): Type = val res = super.recheck(tree, pt) @@ -319,6 +331,42 @@ class CheckCaptures extends Recheck: case _ => super.recheckFinish(tpe, tree, pt) + /** This method implements the rule outlined in #14390: + * When checking an expression `e: Ca Ta` against an expected type `Cx Tx` + * where the capture set of `Cx` contains this and any method inside the class + * `Cls` of `this` that contains `e` has only pure parameters, drop from `Ca` + * all references to variables or this references outside `Cls`. These are all + * accessed through this, so are already accounted for by `Cx`. + */ + override def checkConformsExpr(original: Type, actual: Type, expected: Type, tree: Tree)(using Context): Unit = + def isPure(info: Type): Boolean = info match + case info: PolyType => isPure(info.resType) + case info: MethodType => info.paramInfos.forall(_.captureSet.isAlwaysEmpty) && isPure(info.resType) + case _ => true + def isPureContext(owner: Symbol, limit: Symbol): Boolean = + if owner == limit then true + else if !owner.exists then false + else isPure(owner.info) && isPureContext(owner.owner, limit) + val actual1 = (expected, actual.widen) match + case (CapturingType(ecore, erefs, _), actualw @ CapturingType(acore, arefs, _)) => + val arefs1 = (arefs /: erefs.elems) { (arefs1, eref) => + eref match + case eref: ThisType if isPureContext(ctx.owner, eref.cls) => + arefs1.filter { + case aref1: TermRef => !eref.cls.isContainedIn(aref1.symbol.owner) + case aref1: ThisType => !eref.cls.isContainedIn(aref1.cls) + case _ => true + } + case _ => + arefs1 + } + if arefs1 eq arefs then actual + else actualw.derivedCapturingType(acore, arefs1) + .showing(i"healing $actual --> $result", capt) + case _ => + actual + super.checkConformsExpr(original, actual1, expected, tree) + override def checkUnit(unit: CompilationUnit)(using Context): Unit = Setup(preRecheckPhase, thisPhase, recheckDef) .traverse(ctx.compilationUnit.tpdTree) diff --git a/tests/neg-custom-args/captures/lazylist.check b/tests/neg-custom-args/captures/lazylist.check index 31624e437928..bdbef10de0d6 100644 --- a/tests/neg-custom-args/captures/lazylist.check +++ b/tests/neg-custom-args/captures/lazylist.check @@ -37,6 +37,6 @@ longer explanation available when compiling with `-explain` 17 | def tail = xs() // error: cannot have an inferred type | ^^^^^^^^^^^^^^^ | Non-local method tail cannot have an inferred result type - | {*} lazylists.LazyList[T] - | with non-empty capture set {*}. + | {LazyCons.this.xs} lazylists.LazyList[T] + | with non-empty capture set {LazyCons.this.xs}. | The type needs to be declared explicitly. diff --git a/tests/neg-custom-args/captures/lazylists1.check b/tests/neg-custom-args/captures/lazylists1.check index 29291c8044c0..1d23de2ff134 100644 --- a/tests/neg-custom-args/captures/lazylists1.check +++ b/tests/neg-custom-args/captures/lazylists1.check @@ -1,7 +1,7 @@ --- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylists1.scala:25:63 ----------------------------------- -25 | def concat(other: {f} LazyList[A]): {this} LazyList[A] = ??? : ({xs, f} LazyList[A]) // error - | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ - | Found: {xs, f} LazyList[A] - | Required: {Mapped.this, xs} LazyList[A] +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylists1.scala:25:66 ----------------------------------- +25 | def concat(other: {f} LazyList[A]): {this, f} LazyList[A] = ??? : ({xs, f} LazyList[A]) // error + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | Found: {xs, f} LazyList[A] + | Required: {Mapped.this, f} LazyList[A] longer explanation available when compiling with `-explain` diff --git a/tests/neg-custom-args/captures/lazylists1.scala b/tests/neg-custom-args/captures/lazylists1.scala index 4091ee2c62ae..c6475223b783 100644 --- a/tests/neg-custom-args/captures/lazylists1.scala +++ b/tests/neg-custom-args/captures/lazylists1.scala @@ -22,6 +22,6 @@ extension [A](xs: {*} LazyList[A]) def head: B = f(xs.head) def tail: {this} LazyList[B] = xs.tail.map(f) // OK def drop(n: Int): {this} LazyList[B] = ??? : ({xs, f} LazyList[B]) // OK - def concat(other: {f} LazyList[A]): {this} LazyList[A] = ??? : ({xs, f} LazyList[A]) // error + def concat(other: {f} LazyList[A]): {this, f} LazyList[A] = ??? : ({xs, f} LazyList[A]) // error new Mapped diff --git a/tests/neg-custom-args/captures/lazylists2.check b/tests/neg-custom-args/captures/lazylists2.check index a8f145ecd9d7..8cb825f02537 100644 --- a/tests/neg-custom-args/captures/lazylists2.check +++ b/tests/neg-custom-args/captures/lazylists2.check @@ -29,17 +29,29 @@ longer explanation available when compiling with `-explain` 32 | new Mapped longer explanation available when compiling with `-explain` --- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylists2.scala:41:48 ----------------------------------- -41 | def tail: {this} LazyList[B] = xs.tail.map(f) // error - | ^^^^^^^^^^^^^^ - | Found: {f} LazyList[B] - | Required: {xs} LazyList[B] +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylists2.scala:36:4 ------------------------------------ +36 | final class Mapped extends LazyList[B]: // error + | ^ + | Found: {f, xs} LazyList[B] + | Required: {xs} LazyList[B] +37 | this: ({xs} Mapped) => +38 | def isEmpty = false +39 | def head: B = f(xs.head) +40 | def tail: {this} LazyList[B] = xs.tail.map(f) +41 | new Mapped longer explanation available when compiling with `-explain` --- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylists2.scala:59:48 ----------------------------------- -59 | def tail: {this} LazyList[B] = xs.tail.map(f) // error - | ^^^^^^^^^^^^^^ - | Found: {f} LazyList[B] - | Required: {Mapped.this} LazyList[B] +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylists2.scala:54:4 ------------------------------------ +54 | class Mapped extends LazyList[B]: // error + | ^ + | Found: {f, xs} LazyList[B] + | Required: LazyList[B] +55 | this: ({xs, f} Mapped) => +56 | def isEmpty = false +57 | def head: B = f(xs.head) +58 | def tail: {this} LazyList[B] = xs.tail.map(f) +59 | class Mapped2 extends Mapped: +60 | this: Mapped => +61 | new Mapped2 longer explanation available when compiling with `-explain` diff --git a/tests/neg-custom-args/captures/lazylists2.scala b/tests/neg-custom-args/captures/lazylists2.scala index b9ebb0a7a9f0..fc1ab768047a 100644 --- a/tests/neg-custom-args/captures/lazylists2.scala +++ b/tests/neg-custom-args/captures/lazylists2.scala @@ -33,12 +33,12 @@ extension [A](xs: {*} LazyList[A]) new Mapped def map3[B](f: A => B): {xs} LazyList[B] = - final class Mapped extends LazyList[B]: + final class Mapped extends LazyList[B]: // error this: ({xs} Mapped) => def isEmpty = false def head: B = f(xs.head) - def tail: {this} LazyList[B] = xs.tail.map(f) // error + def tail: {this} LazyList[B] = xs.tail.map(f) new Mapped def map4[B](f: A => B): {xs} LazyList[B] = @@ -51,12 +51,12 @@ extension [A](xs: {*} LazyList[A]) new Mapped def map5[B](f: A => B): LazyList[B] = - class Mapped extends LazyList[B]: + class Mapped extends LazyList[B]: // error this: ({xs, f} Mapped) => def isEmpty = false def head: B = f(xs.head) - def tail: {this} LazyList[B] = xs.tail.map(f) // error + def tail: {this} LazyList[B] = xs.tail.map(f) class Mapped2 extends Mapped: this: Mapped => new Mapped2 diff --git a/tests/pos-custom-args/captures/lazylists-exceptions.scala b/tests/pos-custom-args/captures/lazylists-exceptions.scala new file mode 100644 index 000000000000..6b2378906b8e --- /dev/null +++ b/tests/pos-custom-args/captures/lazylists-exceptions.scala @@ -0,0 +1,68 @@ +import language.experimental.saferExceptions +import annotation.unchecked.uncheckedVariance + +trait LazyList[+A]: + this: {*} LazyList[A] => + + def isEmpty: Boolean + def head: A + def tail: {this} LazyList[A] + +object LazyNil extends LazyList[Nothing]: + def isEmpty: Boolean = true + def head = ??? + def tail = ??? + +final class LazyCons[+T](val x: T, val xs: () => {*} LazyList[T]) extends LazyList[T]: + this: {*} LazyList[T] => + + var forced = false + var cache: {this} LazyList[T @uncheckedVariance] = compiletime.uninitialized + + private def force = + if !forced then + cache = xs() + forced = true + cache + + def isEmpty = false + def head = x + def tail: {this} LazyList[T] = force +end LazyCons + +extension [A](xs: {*} LazyList[A]) + def map[B](f: A => B): {xs, f} LazyList[B] = + if xs.isEmpty then LazyNil + else LazyCons(f(xs.head), () => xs.tail.map(f)) + + def filter(p: A => Boolean): {xs, p} LazyList[A] = + if xs.isEmpty then LazyNil + else if p(xs.head) then LazyCons(xs.head, () => xs.tail.filter(p)) + else xs.tail.filter(p) + + def concat(ys: {*} LazyList[A]): {xs, ys} LazyList[A] = + if xs.isEmpty then ys + else LazyCons(xs.head, () => xs.tail.concat(ys)) +end extension + +class Ex1 extends Exception +class Ex2 extends Exception + +def test(using cap1: CanThrow[Ex1], cap2: CanThrow[Ex2]) = + val xs = LazyCons(1, () => LazyNil) + + def f(x: Int): Int throws Ex1 = + if x < 0 then throw Ex1() + x * x + + def g(x: Int): Int throws Ex1 = + if x < 0 then throw Ex1() + x * x + + def x1 = xs.map(f) + def x1c: {cap1} LazyList[Int] = x1 + + def x2 = x1.concat(xs.map(g).filter(_ > 0)) + def x2c: {cap1, cap2} LazyList[Int] = x2 + + diff --git a/tests/pos-custom-args/captures/lazylists.scala b/tests/pos-custom-args/captures/lazylists.scala index c566bea8dd64..fd130c87cdea 100644 --- a/tests/pos-custom-args/captures/lazylists.scala +++ b/tests/pos-custom-args/captures/lazylists.scala @@ -21,7 +21,6 @@ extension [A](xs: {*} LazyList[A]) def isEmpty = false def head: B = f(xs.head) def tail: {this} LazyList[B] = xs.tail.map(f) // OK - def concat(other: {f} LazyList[A]): {this, f} LazyList[A] = ??? : {xs, f} LazyList[A] // OK if xs.isEmpty then LazyNil else new Mapped diff --git a/tests/pos-custom-args/captures/lazylists1.scala b/tests/pos-custom-args/captures/lazylists1.scala index 2dbb5ac232e2..c203450499e7 100644 --- a/tests/pos-custom-args/captures/lazylists1.scala +++ b/tests/pos-custom-args/captures/lazylists1.scala @@ -7,29 +7,36 @@ trait LazyList[+A]: def isEmpty: Boolean def head: A def tail: {this} LazyList[A] + def concat[B >: A](other: {*} LazyList[B]): {this, other} LazyList[B] object LazyNil extends LazyList[Nothing]: def isEmpty: Boolean = true def head = ??? def tail = ??? + def concat[B](other: {*} LazyList[B]): {other} LazyList[B] = other -final class LazyCons[+T](val x: T, val xs: Int => {*} LazyList[T]) extends LazyList[T]: - this: {*} LazyList[T] => +final class LazyCons[+A](val x: A, val xs: () => {*} LazyList[A]) extends LazyList[A]: + this: {*} LazyList[A] => def isEmpty = false def head = x - def tail: {this} LazyList[T] = xs(0) + def tail: {this} LazyList[A] = xs() + def concat[B >: A](other: {*} LazyList[B]): {this, other} LazyList[B] = + LazyCons(head, () => tail.concat(other)) extension [A](xs: {*} LazyList[A]) def map[B](f: A => B): {xs, f} LazyList[B] = if xs.isEmpty then LazyNil - else LazyCons(f(xs.head), x => xs.tail.map(f)) + else LazyCons(f(xs.head), () => xs.tail.map(f)) def test(cap1: Cap, cap2: Cap) = def f(x: String): String = if cap1 == cap1 then "" else "a" def g(x: String): String = if cap2 == cap2 then "" else "a" - val xs = new LazyCons("", x => if f("") == f("") then LazyNil else LazyNil) + val xs = new LazyCons("", () => if f("") == f("") then LazyNil else LazyNil) val xsc: {cap1} LazyList[String] = xs val ys = xs.map(g) val ysc: {cap1, cap2} LazyList[String] = ys + val zs = new LazyCons("", () => if g("") == g("") then LazyNil else LazyNil) + val as = xs.concat(zs) + val asc: {xs, zs} LazyList[String] = as