From 98bc26d3a0aface96da1d2a24a907afcc31b1e51 Mon Sep 17 00:00:00 2001 From: odersky Date: Tue, 17 Aug 2021 12:38:12 +0200 Subject: [PATCH 01/24] Merge pull request #12971 from dotty-staging/add-rechecker Add recheck phase --- compiler/src/dotty/tools/dotc/Compiler.scala | 2 + .../dotty/tools/dotc/config/Printers.scala | 1 + .../tools/dotc/config/ScalaSettings.scala | 1 + .../src/dotty/tools/dotc/core/NamerOps.scala | 20 ++ .../src/dotty/tools/dotc/core/Phases.scala | 3 + .../dotty/tools/dotc/core/TypeComparer.scala | 2 + .../tools/dotc/transform/PreRecheck.scala | 21 ++ .../dotty/tools/dotc/transform/Recheck.scala | 334 ++++++++++++++++++ .../tools/dotc/typer/RefineTypes.overflow | 0 .../dotty/tools/dotc/typer/TypeAssigner.scala | 18 +- compiler/test/dotc/pos-test-recheck.excludes | 3 + compiler/test/dotc/run-test-recheck.excludes | 0 compiler/test/dotty/tools/TestSources.scala | 4 + .../dotty/tools/dotc/CompilationTests.scala | 8 + .../tools/vulpix/TestConfiguration.scala | 1 + tests/neg/i6635a.scala | 19 + tests/pos/i6635.scala | 7 +- tests/pos/i6635a.scala | 14 + 18 files changed, 446 insertions(+), 12 deletions(-) create mode 100644 compiler/src/dotty/tools/dotc/transform/PreRecheck.scala create mode 100644 compiler/src/dotty/tools/dotc/transform/Recheck.scala create mode 100644 compiler/src/dotty/tools/dotc/typer/RefineTypes.overflow create mode 100644 compiler/test/dotc/pos-test-recheck.excludes create mode 100644 compiler/test/dotc/run-test-recheck.excludes create mode 100644 tests/neg/i6635a.scala create mode 100644 tests/pos/i6635a.scala diff --git a/compiler/src/dotty/tools/dotc/Compiler.scala b/compiler/src/dotty/tools/dotc/Compiler.scala index cbcc62b7fb6b..d5d77929d59d 100644 --- a/compiler/src/dotty/tools/dotc/Compiler.scala +++ b/compiler/src/dotty/tools/dotc/Compiler.scala @@ -103,6 +103,8 @@ class Compiler { new TupleOptimizations, // Optimize generic operations on tuples new LetOverApply, // Lift blocks from receivers of applications new ArrayConstructors) :: // Intercept creation of (non-generic) arrays and intrinsify. + List(new PreRecheck) :: + List(new TestRecheck) :: List(new Erasure) :: // Rewrite types to JVM model, erasing all type parameters, abstract types and refinements. List(new ElimErasedValueType, // Expand erased value types to their underlying implmementation types new PureStats, // Remove pure stats from blocks diff --git a/compiler/src/dotty/tools/dotc/config/Printers.scala b/compiler/src/dotty/tools/dotc/config/Printers.scala index 8e13e50e59b7..b71e1e7f188a 100644 --- a/compiler/src/dotty/tools/dotc/config/Printers.scala +++ b/compiler/src/dotty/tools/dotc/config/Printers.scala @@ -38,6 +38,7 @@ object Printers { val pickling = noPrinter val quotePickling = noPrinter val plugins = noPrinter + val recheckr = noPrinter val refcheck = noPrinter val simplify = noPrinter val staging = noPrinter diff --git a/compiler/src/dotty/tools/dotc/config/ScalaSettings.scala b/compiler/src/dotty/tools/dotc/config/ScalaSettings.scala index 4dccad86e98c..627c027bfd34 100644 --- a/compiler/src/dotty/tools/dotc/config/ScalaSettings.scala +++ b/compiler/src/dotty/tools/dotc/config/ScalaSettings.scala @@ -311,6 +311,7 @@ private sealed trait YSettings: val YcheckInit: Setting[Boolean] = BooleanSetting("-Ysafe-init", "Ensure safe initialization of objects") val YrequireTargetName: Setting[Boolean] = BooleanSetting("-Yrequire-targetName", "Warn if an operator is defined without a @targetName annotation") val YscalaRelease: Setting[String] = ChoiceSetting("-Yscala-release", "release", "Emit TASTy files that can be consumed by specified version of the compiler. The compilation will fail if for any reason valid TASTy cannot be produced (e.g. the code contains references to some parts of the standard library API that are missing in the older stdlib or uses language features unexpressible in the older version of TASTy format)", ScalaSettings.supportedScalaReleaseVersions, "", aliases = List("--Yscala-release")) + val Yrecheck: Setting[Boolean] = BooleanSetting("-Yrecheck", "Run type rechecks (test only)") /** Area-specific debug output */ val YexplainLowlevel: Setting[Boolean] = BooleanSetting("-Yexplain-lowlevel", "When explaining type errors, show types at a lower level.") diff --git a/compiler/src/dotty/tools/dotc/core/NamerOps.scala b/compiler/src/dotty/tools/dotc/core/NamerOps.scala index 9444270ccb05..a5d3e95c8f3e 100644 --- a/compiler/src/dotty/tools/dotc/core/NamerOps.scala +++ b/compiler/src/dotty/tools/dotc/core/NamerOps.scala @@ -177,4 +177,24 @@ object NamerOps: cls.registeredCompanion = modcls modcls.registeredCompanion = cls + /** For secondary constructors, make it known in the context that their type parameters + * are aliases of the class type parameters. This is done by (ab?)-using GADT constraints. + * See pos/i941.scala + */ + def linkConstructorParams(sym: Symbol)(using Context): Context = + if sym.isConstructor && !sym.isPrimaryConstructor then + sym.rawParamss match + case (tparams @ (tparam :: _)) :: _ if tparam.isType => + val rhsCtx = ctx.fresh.setFreshGADTBounds + rhsCtx.gadt.addToConstraint(tparams) + tparams.lazyZip(sym.owner.typeParams).foreach { (psym, tparam) => + val tr = tparam.typeRef + rhsCtx.gadt.addBound(psym, tr, isUpper = false) + rhsCtx.gadt.addBound(psym, tr, isUpper = true) + } + rhsCtx + case _ => + ctx + else ctx + end NamerOps diff --git a/compiler/src/dotty/tools/dotc/core/Phases.scala b/compiler/src/dotty/tools/dotc/core/Phases.scala index 623286d837b3..63a70ad73a83 100644 --- a/compiler/src/dotty/tools/dotc/core/Phases.scala +++ b/compiler/src/dotty/tools/dotc/core/Phases.scala @@ -298,6 +298,9 @@ object Phases { /** If set, implicit search is enabled */ def allowsImplicitSearch: Boolean = false + /** If set equate Skolem types with underlying types */ + def widenSkolems: Boolean = false + /** List of names of phases that should precede this phase */ def runsAfter: Set[String] = Set.empty diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 8b4eab685f2a..4a0dc5f03410 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -739,6 +739,8 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling false } compareClassInfo + case tp2: SkolemType => + ctx.phase.widenSkolems && recur(tp1, tp2.info) || fourthTry case _ => fourthTry } diff --git a/compiler/src/dotty/tools/dotc/transform/PreRecheck.scala b/compiler/src/dotty/tools/dotc/transform/PreRecheck.scala new file mode 100644 index 000000000000..ab27bb3bb306 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/transform/PreRecheck.scala @@ -0,0 +1,21 @@ +package dotty.tools.dotc +package transform + +import core.Phases.Phase +import core.DenotTransformers.IdentityDenotTransformer +import core.Contexts.{Context, ctx} + +/** A phase that precedes the rechecker and that allows installing + * new types for local symbols. + */ +class PreRecheck extends Phase, IdentityDenotTransformer: + + def phaseName: String = "preRecheck" + + override def isEnabled(using Context) = next.isEnabled + + override def changesBaseTypes: Boolean = true + + def run(using Context): Unit = () + + override def isCheckable = false diff --git a/compiler/src/dotty/tools/dotc/transform/Recheck.scala b/compiler/src/dotty/tools/dotc/transform/Recheck.scala new file mode 100644 index 000000000000..76f89cb65757 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/transform/Recheck.scala @@ -0,0 +1,334 @@ +package dotty.tools +package dotc +package transform + +import core.* +import Symbols.*, Contexts.*, Types.*, ContextOps.*, Decorators.*, SymDenotations.* +import Flags.*, SymUtils.*, NameKinds.* +import ast.* +import Phases.Phase +import DenotTransformers.IdentityDenotTransformer +import NamerOps.{methodType, linkConstructorParams} +import NullOpsDecorator.stripNull +import typer.ErrorReporting.err +import typer.ProtoTypes.* +import typer.TypeAssigner.seqLitType +import typer.ConstFold +import config.Printers.recheckr +import util.Property +import StdNames.nme +import reporting.trace + +abstract class Recheck extends Phase, IdentityDenotTransformer: + thisPhase => + + import ast.tpd.* + + def preRecheckPhase = this.prev.asInstanceOf[PreRecheck] + + override def isEnabled(using Context) = ctx.settings.Yrecheck.value + override def changesBaseTypes: Boolean = true + + override def isCheckable = false + // TODO: investigate what goes wrong we Ycheck directly after rechecking. + // One failing test is pos/i583a.scala + + override def widenSkolems = true + + def run(using Context): Unit = + newRechecker().checkUnit(ctx.compilationUnit) + + def newRechecker()(using Context): Rechecker + + class Rechecker(ictx: Context): + val ta = ictx.typeAssigner + + extension (sym: Symbol) def updateInfo(newInfo: Type)(using Context): Unit = + if sym.info ne newInfo then + sym.copySymDenotation().installAfter(thisPhase) // reset + sym.copySymDenotation( + info = newInfo, + initFlags = + if newInfo.isInstanceOf[LazyType] then sym.flags &~ Touched + else sym.flags + ).installAfter(preRecheckPhase) + + /** Hook to be overridden */ + protected def reinfer(tp: Type)(using Context): Type = tp + + def reinferResult(info: Type)(using Context): Type = info match + case info: MethodOrPoly => + info.derivedLambdaType(resType = reinferResult(info.resultType)) + case _ => + reinfer(info) + + def enterDef(stat: Tree)(using Context): Unit = + val sym = stat.symbol + stat match + case stat: ValOrDefDef if stat.tpt.isInstanceOf[InferredTypeTree] => + sym.updateInfo(reinferResult(sym.info)) + case stat: Bind => + sym.updateInfo(reinferResult(sym.info)) + case _ => + + def constFold(tree: Tree, tp: Type)(using Context): Type = + val tree1 = tree.withType(tp) + val tree2 = ConstFold(tree1) + if tree2 ne tree1 then tree2.tpe else tp + + def recheckIdent(tree: Ident)(using Context): Type = + tree.tpe + + /** Keep the symbol of the `select` but re-infer its type */ + def recheckSelect(tree: Select)(using Context): Type = tree match + case Select(qual, name) => + val qualType = recheck(qual).widenIfUnstable + if name.is(OuterSelectName) then tree.tpe + else + //val pre = ta.maybeSkolemizePrefix(qualType, name) + val mbr = qualType.findMember(name, qualType, + excluded = if tree.symbol.is(Private) then EmptyFlags else Private + ).suchThat(tree.symbol ==) + constFold(tree, qualType.select(name, mbr)) + + def recheckBind(tree: Bind, pt: Type)(using Context): Type = tree match + case Bind(name, body) => + enterDef(tree) + recheck(body, pt) + val sym = tree.symbol + if sym.isType then sym.typeRef else sym.info + + def recheckLabeled(tree: Labeled, pt: Type)(using Context): Type = tree match + case Labeled(bind, expr) => + val bindType = recheck(bind, pt) + val exprType = recheck(expr, defn.UnitType) + bindType + + def recheckValDef(tree: ValDef, sym: Symbol)(using Context): Type = + if !tree.rhs.isEmpty then recheck(tree.rhs, tree.symbol.info) + sym.termRef + + def recheckDefDef(tree: DefDef, sym: Symbol)(using Context): Type = + tree.paramss.foreach(_.foreach(enterDef)) + val rhsCtx = linkConstructorParams(sym) + if !tree.rhs.isEmpty && !sym.isInlineMethod && !sym.isEffectivelyErased then + recheck(tree.rhs, tree.symbol.localReturnType)(using rhsCtx) + sym.termRef + + def recheckTypeDef(tree: TypeDef, sym: Symbol)(using Context): Type = + recheck(tree.rhs) + sym.typeRef + + def recheckClassDef(tree: TypeDef, impl: Template, sym: ClassSymbol)(using Context): Type = + recheck(impl.constr) + impl.parentsOrDerived.foreach(recheck(_)) + recheck(impl.self) + recheckStats(impl.body) + sym.typeRef + + // Need to remap Object to FromJavaObject since it got lost in ElimRepeated + private def mapJavaArgs(formals: List[Type])(using Context): List[Type] = + val tm = new TypeMap: + def apply(t: Type) = t match + case t: TypeRef if t.symbol == defn.ObjectClass => defn.FromJavaObjectType + case _ => mapOver(t) + formals.mapConserve(tm) + + def recheckApply(tree: Apply, pt: Type)(using Context): Type = + recheck(tree.fun).widen match + case fntpe: MethodType => + assert(sameLength(fntpe.paramInfos, tree.args)) + val formals = + if tree.symbol.is(JavaDefined) then mapJavaArgs(fntpe.paramInfos) + else fntpe.paramInfos + def recheckArgs(args: List[Tree], formals: List[Type], prefs: List[ParamRef]): List[Type] = args match + case arg :: args1 => + val argType = recheck(arg, formals.head) + val formals1 = + if fntpe.isParamDependent + then formals.tail.map(_.substParam(prefs.head, argType)) + else formals.tail + argType :: recheckArgs(args1, formals1, prefs.tail) + case Nil => + assert(formals.isEmpty) + Nil + val argTypes = recheckArgs(tree.args, formals, fntpe.paramRefs) + constFold(tree, fntpe.instantiate(argTypes)) + + def recheckTypeApply(tree: TypeApply, pt: Type)(using Context): Type = + recheck(tree.fun).widen match + case fntpe: PolyType => + assert(sameLength(fntpe.paramInfos, tree.args)) + val argTypes = tree.args.map(recheck(_)) + constFold(tree, fntpe.instantiate(argTypes)) + + def recheckTyped(tree: Typed)(using Context): Type = + val tptType = recheck(tree.tpt) + recheck(tree.expr, tptType) + tptType + + def recheckAssign(tree: Assign)(using Context): Type = + val lhsType = recheck(tree.lhs) + recheck(tree.rhs, lhsType.widen) + defn.UnitType + + def recheckBlock(stats: List[Tree], expr: Tree, pt: Type)(using Context): Type = + recheckStats(stats) + val exprType = recheck(expr, pt.dropIfProto) + TypeOps.avoid(exprType, localSyms(stats).filterConserve(_.isTerm)) + + def recheckBlock(tree: Block, pt: Type)(using Context): Type = + recheckBlock(tree.stats, tree.expr, pt) + + def recheckInlined(tree: Inlined, pt: Type)(using Context): Type = + recheckBlock(tree.bindings, tree.expansion, pt) + + def recheckIf(tree: If, pt: Type)(using Context): Type = + recheck(tree.cond, defn.BooleanType) + recheck(tree.thenp, pt) | recheck(tree.elsep, pt) + + def recheckClosure(tree: Closure, pt: Type)(using Context): Type = + if tree.tpt.isEmpty then + tree.meth.tpe.widen.toFunctionType(tree.meth.symbol.is(JavaDefined)) + else + recheck(tree.tpt) + + def recheckMatch(tree: Match, pt: Type)(using Context): Type = + val selectorType = recheck(tree.selector) + val casesTypes = tree.cases.map(recheck(_, selectorType.widen, pt)) + TypeComparer.lub(casesTypes) + + def recheck(tree: CaseDef, selType: Type, pt: Type)(using Context): Type = + recheck(tree.pat, selType) + recheck(tree.guard, defn.BooleanType) + recheck(tree.body, pt) + + def recheckReturn(tree: Return)(using Context): Type = + recheck(tree.expr, tree.from.symbol.returnProto) + defn.NothingType + + def recheckWhileDo(tree: WhileDo)(using Context): Type = + recheck(tree.cond, defn.BooleanType) + recheck(tree.body, defn.UnitType) + defn.UnitType + + def recheckTry(tree: Try, pt: Type)(using Context): Type = + val bodyType = recheck(tree.expr, pt) + val casesTypes = tree.cases.map(recheck(_, defn.ThrowableType, pt)) + val finalizerType = recheck(tree.finalizer, defn.UnitType) + TypeComparer.lub(bodyType :: casesTypes) + + def recheckSeqLiteral(tree: SeqLiteral, pt: Type)(using Context): Type = + val elemProto = pt.stripNull.elemType match + case NoType => WildcardType + case bounds: TypeBounds => WildcardType(bounds) + case elemtp => elemtp + val declaredElemType = recheck(tree.elemtpt) + val elemTypes = tree.elems.map(recheck(_, elemProto)) + seqLitType(tree, TypeComparer.lub(declaredElemType :: elemTypes)) + + def recheckTypeTree(tree: TypeTree)(using Context): Type = tree match + case tree: InferredTypeTree => reinfer(tree.tpe) + case _ => tree.tpe + + def recheckAnnotated(tree: Annotated)(using Context): Type = + tree.tpe match + case tp: AnnotatedType => + val argType = recheck(tree.arg) + tp.derivedAnnotatedType(argType, tp.annot) + + def recheckAlternative(tree: Alternative, pt: Type)(using Context): Type = + val altTypes = tree.trees.map(recheck(_, pt)) + TypeComparer.lub(altTypes) + + def recheckPackageDef(tree: PackageDef)(using Context): Type = + recheckStats(tree.stats) + NoType + + def recheckStats(stats: List[Tree])(using Context): Unit = + stats.foreach(enterDef) + stats.foreach(recheck(_)) + + /** Recheck tree without adapting it, returning its new type. + * @param tree the original tree + * @param pt the expected result type + */ + def recheck(tree: Tree, pt: Type = WildcardType)(using Context): Type = trace(i"rechecking $tree with pt = $pt", recheckr, show = true) { + + def recheckNamed(tree: NameTree, pt: Type)(using Context): Type = + val sym = tree.symbol + tree match + case tree: Ident => recheckIdent(tree) + case tree: Select => recheckSelect(tree) + case tree: Bind => recheckBind(tree, pt) + case tree: ValDef => + if tree.isEmpty then NoType + else recheckValDef(tree, sym)(using ctx.localContext(tree, sym)) + case tree: DefDef => + recheckDefDef(tree, sym)(using ctx.localContext(tree, sym)) + case tree: TypeDef => + tree.rhs match + case impl: Template => + recheckClassDef(tree, impl, sym.asClass)(using ctx.localContext(tree, sym)) + case _ => + recheckTypeDef(tree, sym)(using ctx.localContext(tree, sym)) + case tree: Labeled => recheckLabeled(tree, pt) + + def recheckUnnamed(tree: Tree, pt: Type): Type = tree match + case tree: Apply => recheckApply(tree, pt) + case tree: TypeApply => recheckTypeApply(tree, pt) + case _: New | _: This | _: Super | _: Literal => tree.tpe + case tree: Typed => recheckTyped(tree) + case tree: Assign => recheckAssign(tree) + case tree: Block => recheckBlock(tree, pt) + case tree: If => recheckIf(tree, pt) + case tree: Closure => recheckClosure(tree, pt) + case tree: Match => recheckMatch(tree, pt) + case tree: Return => recheckReturn(tree) + case tree: WhileDo => recheckWhileDo(tree) + case tree: Try => recheckTry(tree, pt) + case tree: SeqLiteral => recheckSeqLiteral(tree, pt) + case tree: Inlined => recheckInlined(tree, pt) + case tree: TypeTree => recheckTypeTree(tree) + case tree: Annotated => recheckAnnotated(tree) + case tree: Alternative => recheckAlternative(tree, pt) + case tree: PackageDef => recheckPackageDef(tree) + case tree: Thicket => defn.NothingType + + try + val result = tree match + case tree: NameTree => recheckNamed(tree, pt) + case tree => recheckUnnamed(tree, pt) + checkConforms(result, pt, tree) + result + catch case ex: Exception => + println(i"error while rechecking $tree") + throw ex + } + end recheck + + def checkConforms(tpe: Type, pt: Type, tree: Tree)(using Context): Unit = tree match + case _: DefTree | EmptyTree | _: TypeTree => + case _ => + val actual = tpe.widenExpr + val expected = pt.widenExpr + val isCompatible = + actual <:< expected + || expected.isRepeatedParam + && actual <:< expected.translateFromRepeated(toArray = tree.tpe.isRef(defn.ArrayClass)) + if !isCompatible then + println(i"err at ${ctx.phase}") + err.typeMismatch(tree.withType(tpe), pt) + + def checkUnit(unit: CompilationUnit)(using Context): Unit = + recheck(unit.tpdTree) + + end Rechecker +end Recheck + +class TestRecheck extends Recheck: + def phaseName: String = "recheck" + //override def isEnabled(using Context) = ctx.settings.YrefineTypes.value + def newRechecker()(using Context): Rechecker = Rechecker(ctx) + + diff --git a/compiler/src/dotty/tools/dotc/typer/RefineTypes.overflow b/compiler/src/dotty/tools/dotc/typer/RefineTypes.overflow new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala b/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala index 3dcec413540f..7b67d828ddcb 100644 --- a/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala +++ b/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala @@ -17,7 +17,8 @@ import reporting._ import Checking.{checkNoPrivateLeaks, checkNoWildcard} trait TypeAssigner { - import tpd._ + import tpd.* + import TypeAssigner.* /** The qualifying class of a this or super with prefix `qual` (which might be empty). * @param packageOk The qualifier may refer to a package. @@ -435,13 +436,8 @@ trait TypeAssigner { if (cases.isEmpty) tree.withType(expr.tpe) else tree.withType(TypeComparer.lub(expr.tpe :: cases.tpes)) - def assignType(tree: untpd.SeqLiteral, elems: List[Tree], elemtpt: Tree)(using Context): SeqLiteral = { - val ownType = tree match { - case tree: untpd.JavaSeqLiteral => defn.ArrayOf(elemtpt.tpe) - case _ => if (ctx.erasedTypes) defn.SeqType else defn.SeqType.appliedTo(elemtpt.tpe) - } - tree.withType(ownType) - } + def assignType(tree: untpd.SeqLiteral, elems: List[Tree], elemtpt: Tree)(using Context): SeqLiteral = + tree.withType(seqLitType(tree, elemtpt.tpe)) def assignType(tree: untpd.SingletonTypeTree, ref: Tree)(using Context): SingletonTypeTree = tree.withType(ref.tpe) @@ -527,5 +523,9 @@ trait TypeAssigner { tree.withType(pid.symbol.termRef) } +object TypeAssigner extends TypeAssigner: + def seqLitType(tree: untpd.SeqLiteral, elemType: Type)(using Context) = tree match + case tree: untpd.JavaSeqLiteral => defn.ArrayOf(elemType) + case _ => if ctx.erasedTypes then defn.SeqType else defn.SeqType.appliedTo(elemType) + -object TypeAssigner extends TypeAssigner diff --git a/compiler/test/dotc/pos-test-recheck.excludes b/compiler/test/dotc/pos-test-recheck.excludes new file mode 100644 index 000000000000..e973b2cd529f --- /dev/null +++ b/compiler/test/dotc/pos-test-recheck.excludes @@ -0,0 +1,3 @@ +# Cannot compensate dealiasing due to false result dependency +i6635a.scala +i6682a.scala diff --git a/compiler/test/dotc/run-test-recheck.excludes b/compiler/test/dotc/run-test-recheck.excludes new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/compiler/test/dotty/tools/TestSources.scala b/compiler/test/dotty/tools/TestSources.scala index 4fbf0e9fc5dd..60070bb15af3 100644 --- a/compiler/test/dotty/tools/TestSources.scala +++ b/compiler/test/dotty/tools/TestSources.scala @@ -11,17 +11,21 @@ object TestSources { def posFromTastyBlacklistFile: String = "compiler/test/dotc/pos-from-tasty.blacklist" def posTestPicklingBlacklistFile: String = "compiler/test/dotc/pos-test-pickling.blacklist" + def posTestRecheckExcludesFile = "compiler/test/dotc/pos-test-recheck.excludes" def posFromTastyBlacklisted: List[String] = loadList(posFromTastyBlacklistFile) def posTestPicklingBlacklisted: List[String] = loadList(posTestPicklingBlacklistFile) + def posTestRecheckExcluded = loadList(posTestRecheckExcludesFile) // run tests lists def runFromTastyBlacklistFile: String = "compiler/test/dotc/run-from-tasty.blacklist" def runTestPicklingBlacklistFile: String = "compiler/test/dotc/run-test-pickling.blacklist" + def runTestRecheckExcludesFile = "compiler/test/dotc/run-test-recheck.excludes" def runFromTastyBlacklisted: List[String] = loadList(runFromTastyBlacklistFile) def runTestPicklingBlacklisted: List[String] = loadList(runTestPicklingBlacklistFile) + def runTestRecheckExcluded = loadList(runTestRecheckExcludesFile) // load lists diff --git a/compiler/test/dotty/tools/dotc/CompilationTests.scala b/compiler/test/dotty/tools/dotc/CompilationTests.scala index 88eab8d131e6..7bb546150cd8 100644 --- a/compiler/test/dotty/tools/dotc/CompilationTests.scala +++ b/compiler/test/dotty/tools/dotc/CompilationTests.scala @@ -232,6 +232,14 @@ class CompilationTests { ).checkCompile() } + @Test def recheck: Unit = + given TestGroup = TestGroup("recheck") + aggregateTests( + compileFilesInDir("tests/new", recheckOptions), + compileFilesInDir("tests/pos", recheckOptions, FileFilter.exclude(TestSources.posTestRecheckExcluded)), + compileFilesInDir("tests/run", recheckOptions, FileFilter.exclude(TestSources.runTestRecheckExcluded)) + ).checkCompile() + // Explicit nulls tests @Test def explicitNullsNeg: Unit = { implicit val testGroup: TestGroup = TestGroup("explicitNullsNeg") diff --git a/compiler/test/dotty/tools/vulpix/TestConfiguration.scala b/compiler/test/dotty/tools/vulpix/TestConfiguration.scala index b43dcbdd6046..15fe58510628 100644 --- a/compiler/test/dotty/tools/vulpix/TestConfiguration.scala +++ b/compiler/test/dotty/tools/vulpix/TestConfiguration.scala @@ -81,6 +81,7 @@ object TestConfiguration { ) val picklingWithCompilerOptions = picklingOptions.withClasspath(withCompilerClasspath).withRunClasspath(withCompilerClasspath) + val recheckOptions = defaultOptions.and("-Yrecheck") val scala2CompatMode = defaultOptions.and("-source", "3.0-migration") val explicitUTF8 = defaultOptions and ("-encoding", "UTF8") val explicitUTF16 = defaultOptions and ("-encoding", "UTF16") diff --git a/tests/neg/i6635a.scala b/tests/neg/i6635a.scala new file mode 100644 index 000000000000..a79ea4e7c818 --- /dev/null +++ b/tests/neg/i6635a.scala @@ -0,0 +1,19 @@ +object Test { + abstract class ExprBase { s => + type A + } + + abstract class Lit extends ExprBase { s => + type A = Int + val n: A + } + + // It would be nice if the following could typecheck. We'd need to apply + // a reasoning like this: + // + // Since there is an argument `e2` of type `Lit & e1.type`, it follows that + // e1.type == e2.type Hence, e1.A == e2.A == Int. This looks similar + // to techniques used in GADTs. + // + def castTestFail2a(e1: ExprBase)(e2: Lit & e1.type)(x: e1.A): Int = x // error: Found: (x : e1.A) Required: Int +} diff --git a/tests/pos/i6635.scala b/tests/pos/i6635.scala index dacd1ef5cd8b..406eee6251e6 100644 --- a/tests/pos/i6635.scala +++ b/tests/pos/i6635.scala @@ -27,11 +27,12 @@ object Test { def castTest5a(e1: ExprBase)(e2: LitU with e1.type)(x: e2.A): e1.A = x def castTest5b(e1: ExprBase)(e2: LitL with e1.type)(x: e2.A): e1.A = x - //fail: def castTestFail1(e1: ExprBase)(e2: Lit with e1.type)(x: e2.A): e1.A = x // this is like castTest5a/b, but with Lit instead of LitU/LitL - // the other direction never works: - def castTestFail2a(e1: ExprBase)(e2: Lit with e1.type)(x: e1.A): e2.A = x + + // The next example fails rechecking. It is repeated in i6635a.scala + // def castTestFail2a(e1: ExprBase)(e2: Lit with e1.type)(x: e1.A): e2.A = x def castTestFail2b(e1: ExprBase)(e2: LitL with e1.type)(x: e1.A): e2.A = x + def castTestFail2c(e1: ExprBase)(e2: LitU with e1.type)(x: e1.A): e2.A = x // the problem isn't about order of intersections. diff --git a/tests/pos/i6635a.scala b/tests/pos/i6635a.scala new file mode 100644 index 000000000000..9454e03e3a4a --- /dev/null +++ b/tests/pos/i6635a.scala @@ -0,0 +1,14 @@ +object Test { + abstract class ExprBase { s => + type A + } + + abstract class Lit extends ExprBase { s => + type A = Int + val n: A + } + + // Fails recheck since the result type e2.A is converted to Int to avoid + // a false dependency on e2. + def castTestFail2a(e1: ExprBase)(e2: Lit with e1.type)(x: e1.A): e2.A = x +} From 4b88b5a67424bac7faa6f1fe7f123c77fc84e2a7 Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Tue, 17 Aug 2021 13:02:09 +0200 Subject: [PATCH 02/24] Cleanups --- compiler/src/dotty/tools/dotc/Compiler.scala | 4 ++-- compiler/src/dotty/tools/dotc/typer/RefineTypes.overflow | 0 2 files changed, 2 insertions(+), 2 deletions(-) delete mode 100644 compiler/src/dotty/tools/dotc/typer/RefineTypes.overflow diff --git a/compiler/src/dotty/tools/dotc/Compiler.scala b/compiler/src/dotty/tools/dotc/Compiler.scala index d5d77929d59d..fdbf67107965 100644 --- a/compiler/src/dotty/tools/dotc/Compiler.scala +++ b/compiler/src/dotty/tools/dotc/Compiler.scala @@ -103,8 +103,8 @@ class Compiler { new TupleOptimizations, // Optimize generic operations on tuples new LetOverApply, // Lift blocks from receivers of applications new ArrayConstructors) :: // Intercept creation of (non-generic) arrays and intrinsify. - List(new PreRecheck) :: - List(new TestRecheck) :: + List(new PreRecheck) :: // Preparations for recheck phase, enabled under -Yrecheck + List(new TestRecheck) :: // Test rechecking, enabled under -Yrecheck List(new Erasure) :: // Rewrite types to JVM model, erasing all type parameters, abstract types and refinements. List(new ElimErasedValueType, // Expand erased value types to their underlying implmementation types new PureStats, // Remove pure stats from blocks diff --git a/compiler/src/dotty/tools/dotc/typer/RefineTypes.overflow b/compiler/src/dotty/tools/dotc/typer/RefineTypes.overflow deleted file mode 100644 index e69de29bb2d1..000000000000 From 5e4e4f0ad9575539f5aac4fcb6193ec7da0f7dfc Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Wed, 29 Sep 2021 17:46:49 +0200 Subject: [PATCH 03/24] First version of capture checker. A squashed version of the following commits: Handle byname parameters Don't force symbol completion when printing flags or annotations Check overrides and disallow non-local inferred capture types Handle `this` in capture sets Print capture variable dependencies under -Ydebug-cc Avoid spurious error message Avoid spurious error message "cannot be tracked since its capture set is empty". This arose in lazyref.scala for a DependentTypeTree in an anaonymois function. Dependent type trees map to normal TypeTrees, not InferredTypeTrees (and things go wrong if we try to change that). Drop TopType Consider bounds of type variables to be boxed More tests Avoid multiple maps when creating symbol infos Use a single BiTypeMap to map from inferred result and parameters to method info. This improves efficiency and debuggability by reducing the frequence of multiple stacked maps capture sets. Refactor with CompareResult#andAlso Refactoring: use isOK on CompareResult Reflect inferred parameter types in enclosing method type The variables in the inferred parameter type of an anonymous function need to also show up in the closure type itself, so that they can be constrained. Don't interpolate parameters of anonymous functions Here, we should wait until we get the info from the outside, which can be arbitrarily much later. Compute upper approximation of bimapped sets from both sides Fail when trying to add new elements to mapped sets It's the safe option. Print full origin trail of derived capture sets under -Ycc-debug Fix isEmpty condition in well-formedness check Make printing capture sets dependent on -Ycc-debug Recursion brake for upperApprox Fixes to upperApprox Make instantiteRT a BiTypeMap Otherwise we will not be able to do upper approximations of parameters. Interpolate only variables at negative polarity Interpolating covariant variables risks restricting capture sets to early. For instance, when a variable has the capture set of a called function in its capture set. When we have indirectly recursive calls it could be that the capture set of a called function is not yet fully formed. Interpolate type variables when symbols are completed Allow for possibility that variables are constant Only recomplete symbols if their info changes Add completions to Rechecker Complete val and def definitions lazily on first access. Now, recheckDefDef and recheckValDef are called the first time the new info of the defined symbol is needed, or, if the info is never needed, when the typer gets to the definitions. This only applied to definitions with inferred types. The others are handled in typer sequence, as before. The motivation of the change is that some modifications to inferred types of symbols can be made in subclasses without running into ordering problems. More fixes for subCapture New setting -Ycc-debug for more info on capture variables Fix subCapture in frozen state Previously, we still OKed two empty variables to be compared with subcapture in the frozen state. This should give an error. Direct comparisons of dependent function types Revert: Special treatment of dependent functions in TypeComparer change test Also treat explicit capturing type arguments as boxed Print subcapturing steps in -explain traces Don't decorate type variables with additional capture sets Boxed CapturingTypes Drop unsound capture suppression if expected type is boxed If expected type is boxed, the expression still contributes to the captured variables of its environment. Re-infer result types of anonymous functions Keep erased implicit args Special treatment of dependent functions in TypeComparer Fix addFunctionRefinements Always print refined function types as dependent functions. Makes it easier to see what goes on. Make CaptureSet ++ and ** simplify more Refine function types when reinferring so that they can be dependent Fix avoidance problem when typing blocks We should not pass en expected type when rechecking the expression of a block since that can add local references to global capture set variables. Also: tests for lists and pairs Print empty variables with "?" Fix printing untyped annotations Fix printing annotations in trees Drop redundant code Refactor map operations on capture sets Intoduce Bi-Mapped CaptureSets Report an error is a simply mapped capture set gets new elements that do not come from the original souurce. Introduce a new abstraction of bi-mapped sets that accept new elements and propagate them to the original source. Add map operation to SimpleIdentitySet Restrict tracked class parameters to vals Handle local classes and secondary constructors Fix CapturingType precedence when printing First stab at handling classes Bug fixes 1. Fix canBeTracked for TermRefs only TermRefs where prefix is NoPrefix or `this` can be tracked. The others have to be widened. 2. Fix rule for comparing capture refs on the left 3. Be more careful where comparisons are frozen Capture checker for functions --- compiler/src/dotty/tools/dotc/Compiler.scala | 5 +- compiler/src/dotty/tools/dotc/Run.scala | 3 +- .../src/dotty/tools/dotc/ast/Desugar.scala | 5 + compiler/src/dotty/tools/dotc/ast/Trees.scala | 8 +- compiler/src/dotty/tools/dotc/ast/untpd.scala | 11 + .../tools/dotc/cc/CaptureAnnotation.scala | 63 ++ .../src/dotty/tools/dotc/cc/CaptureOps.scala | 82 +++ .../src/dotty/tools/dotc/cc/CaptureSet.scala | 577 ++++++++++++++++++ .../dotty/tools/dotc/cc/CapturingType.scala | 21 + .../src/dotty/tools/dotc/config/Config.scala | 5 + .../dotty/tools/dotc/config/Printers.scala | 1 + .../tools/dotc/config/ScalaSettings.scala | 2 + .../dotty/tools/dotc/core/Annotations.scala | 6 +- .../dotty/tools/dotc/core/Definitions.scala | 17 +- .../tools/dotc/core/OrderingConstraint.scala | 4 + .../src/dotty/tools/dotc/core/Phases.scala | 16 +- .../src/dotty/tools/dotc/core/StdNames.scala | 4 +- .../dotty/tools/dotc/core/Substituters.scala | 6 +- .../tools/dotc/core/SymDenotations.scala | 9 +- .../dotty/tools/dotc/core/TypeComparer.scala | 194 ++++-- .../dotty/tools/dotc/core/TypeErrors.scala | 1 + .../src/dotty/tools/dotc/core/TypeOps.scala | 26 +- .../src/dotty/tools/dotc/core/Types.scala | 223 ++++++- .../src/dotty/tools/dotc/core/Variances.scala | 3 + .../tools/dotc/core/tasty/TreeUnpickler.scala | 10 +- .../dotty/tools/dotc/parsing/Parsers.scala | 34 +- .../src/dotty/tools/dotc/parsing/Tokens.scala | 4 +- .../tools/dotc/printing/PlainPrinter.scala | 30 +- .../dotty/tools/dotc/printing/Printer.scala | 8 +- .../tools/dotc/printing/RefinedPrinter.scala | 23 +- .../dotty/tools/dotc/reporting/messages.scala | 1 - .../src/dotty/tools/dotc/sbt/ExtractAPI.scala | 1 + .../tools/dotc/transform/EmptyPhase.scala | 19 + .../dotty/tools/dotc/transform/Recheck.scala | 238 ++++++-- .../tools/dotc/transform/TreeChecker.scala | 4 +- .../dotc/transform/TryCatchPatterns.scala | 2 +- .../tools/dotc/transform/TypeTestsCasts.scala | 4 +- .../tools/dotc/typer/CheckCaptures.scala | 468 ++++++++++++++ .../src/dotty/tools/dotc/typer/Checking.scala | 6 +- .../dotty/tools/dotc/typer/Inferencing.scala | 8 +- .../dotty/tools/dotc/typer/RefChecks.scala | 2 +- .../dotty/tools/dotc/typer/TypeAssigner.scala | 26 +- .../src/dotty/tools/dotc/typer/Typer.scala | 5 +- .../tools/dotc/util/SimpleIdentitySet.scala | 13 + .../dotty/tools/dotc/CompilationTests.scala | 3 + library/src-bootstrapped/scala/Retains.scala | 6 + .../scala/annotation/ability.scala | 9 + .../scala/runtime/stdLibPatches/Predef.scala | 1 + .../neg-custom-args/captures/capt-wf.scala | 19 + .../neg-custom-args/captures/try2.check | 38 ++ .../neg-custom-args/captures/try2.scala | 55 ++ tests/disabled/pos/lazylist.scala | 51 ++ .../allow-deep-subtypes}/i9325.scala | 0 tests/neg-custom-args/capt-wf.scala | 35 ++ tests/neg-custom-args/captures/bounded.scala | 14 + tests/neg-custom-args/captures/boxmap.check | 7 + tests/neg-custom-args/captures/boxmap.scala | 14 + tests/neg-custom-args/captures/byname.scala | 10 + .../captures/capt-box-env.scala | 12 + tests/neg-custom-args/captures/capt-box.scala | 13 + .../captures/capt-depfun.scala | 7 + .../captures/capt-depfun2.scala | 10 + tests/neg-custom-args/captures/capt-env.scala | 13 + .../neg-custom-args/captures/capt-test.scala | 26 + .../captures/capt-wf-typer.scala | 10 + tests/neg-custom-args/captures/capt1.check | 46 ++ tests/neg-custom-args/captures/capt1.scala | 34 ++ tests/neg-custom-args/captures/capt2.scala | 9 + tests/neg-custom-args/captures/capt3.scala | 26 + tests/neg-custom-args/captures/cc1.scala | 4 + tests/neg-custom-args/captures/classes.scala | 12 + tests/neg-custom-args/captures/io.scala | 22 + tests/neg-custom-args/captures/lazylist.check | 42 ++ tests/neg-custom-args/captures/lazylist.scala | 41 ++ tests/neg-custom-args/captures/lazyref.check | 28 + tests/neg-custom-args/captures/lazyref.scala | 25 + tests/neg-custom-args/captures/try.check | 25 + tests/neg-custom-args/captures/try.scala | 53 ++ tests/neg-custom-args/captures/try3.scala | 27 + tests/neg/multiLineOps.scala | 2 +- tests/neg/polymorphic-functions1.check | 7 + tests/neg/polymorphic-functions1.scala | 1 + tests/pos-custom-args/captures/bounded.scala | 14 + .../captures/boxmap-paper.scala | 38 ++ tests/pos-custom-args/captures/boxmap.scala | 20 + tests/pos-custom-args/captures/byname.scala | 10 + .../captures/capt-depfun.scala | 18 + .../captures/capt-depfun2.scala | 8 + .../pos-custom-args/captures/capt-test.scala | 35 ++ tests/pos-custom-args/captures/capt0.scala | 7 + tests/pos-custom-args/captures/capt1.scala | 27 + tests/pos-custom-args/captures/capt2.scala | 20 + .../pos-custom-args/captures/cc-expand.scala | 21 + tests/pos-custom-args/captures/classes.scala | 34 ++ .../pos-custom-args/captures/iterators.scala | 23 + tests/pos-custom-args/captures/lazyref.scala | 25 + .../captures/list-encoding.scala | 23 + tests/pos-custom-args/captures/lists.scala | 91 +++ tests/pos-custom-args/captures/pairs.scala | 33 + tests/pos-custom-args/captures/try.scala | 26 + tests/pos-custom-args/captures/try3.scala | 51 ++ 101 files changed, 3251 insertions(+), 228 deletions(-) create mode 100644 compiler/src/dotty/tools/dotc/cc/CaptureAnnotation.scala create mode 100644 compiler/src/dotty/tools/dotc/cc/CaptureOps.scala create mode 100644 compiler/src/dotty/tools/dotc/cc/CaptureSet.scala create mode 100644 compiler/src/dotty/tools/dotc/cc/CapturingType.scala create mode 100644 compiler/src/dotty/tools/dotc/transform/EmptyPhase.scala create mode 100644 compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala create mode 100644 library/src-bootstrapped/scala/Retains.scala create mode 100644 library/src-bootstrapped/scala/annotation/ability.scala create mode 100644 tests/disabled/neg-custom-args/captures/capt-wf.scala create mode 100644 tests/disabled/neg-custom-args/captures/try2.check create mode 100644 tests/disabled/neg-custom-args/captures/try2.scala create mode 100644 tests/disabled/pos/lazylist.scala rename tests/{neg => neg-custom-args/allow-deep-subtypes}/i9325.scala (100%) create mode 100644 tests/neg-custom-args/capt-wf.scala create mode 100644 tests/neg-custom-args/captures/bounded.scala create mode 100644 tests/neg-custom-args/captures/boxmap.check create mode 100644 tests/neg-custom-args/captures/boxmap.scala create mode 100644 tests/neg-custom-args/captures/byname.scala create mode 100644 tests/neg-custom-args/captures/capt-box-env.scala create mode 100644 tests/neg-custom-args/captures/capt-box.scala create mode 100644 tests/neg-custom-args/captures/capt-depfun.scala create mode 100644 tests/neg-custom-args/captures/capt-depfun2.scala create mode 100644 tests/neg-custom-args/captures/capt-env.scala create mode 100644 tests/neg-custom-args/captures/capt-test.scala create mode 100644 tests/neg-custom-args/captures/capt-wf-typer.scala create mode 100644 tests/neg-custom-args/captures/capt1.check create mode 100644 tests/neg-custom-args/captures/capt1.scala create mode 100644 tests/neg-custom-args/captures/capt2.scala create mode 100644 tests/neg-custom-args/captures/capt3.scala create mode 100644 tests/neg-custom-args/captures/cc1.scala create mode 100644 tests/neg-custom-args/captures/classes.scala create mode 100644 tests/neg-custom-args/captures/io.scala create mode 100644 tests/neg-custom-args/captures/lazylist.check create mode 100644 tests/neg-custom-args/captures/lazylist.scala create mode 100644 tests/neg-custom-args/captures/lazyref.check create mode 100644 tests/neg-custom-args/captures/lazyref.scala create mode 100644 tests/neg-custom-args/captures/try.check create mode 100644 tests/neg-custom-args/captures/try.scala create mode 100644 tests/neg-custom-args/captures/try3.scala create mode 100644 tests/neg/polymorphic-functions1.check create mode 100644 tests/neg/polymorphic-functions1.scala create mode 100644 tests/pos-custom-args/captures/bounded.scala create mode 100644 tests/pos-custom-args/captures/boxmap-paper.scala create mode 100644 tests/pos-custom-args/captures/boxmap.scala create mode 100644 tests/pos-custom-args/captures/byname.scala create mode 100644 tests/pos-custom-args/captures/capt-depfun.scala create mode 100644 tests/pos-custom-args/captures/capt-depfun2.scala create mode 100644 tests/pos-custom-args/captures/capt-test.scala create mode 100644 tests/pos-custom-args/captures/capt0.scala create mode 100644 tests/pos-custom-args/captures/capt1.scala create mode 100644 tests/pos-custom-args/captures/capt2.scala create mode 100644 tests/pos-custom-args/captures/cc-expand.scala create mode 100644 tests/pos-custom-args/captures/classes.scala create mode 100644 tests/pos-custom-args/captures/iterators.scala create mode 100644 tests/pos-custom-args/captures/lazyref.scala create mode 100644 tests/pos-custom-args/captures/list-encoding.scala create mode 100644 tests/pos-custom-args/captures/lists.scala create mode 100644 tests/pos-custom-args/captures/pairs.scala create mode 100644 tests/pos-custom-args/captures/try.scala create mode 100644 tests/pos-custom-args/captures/try3.scala diff --git a/compiler/src/dotty/tools/dotc/Compiler.scala b/compiler/src/dotty/tools/dotc/Compiler.scala index fdbf67107965..c717d00a9a07 100644 --- a/compiler/src/dotty/tools/dotc/Compiler.scala +++ b/compiler/src/dotty/tools/dotc/Compiler.scala @@ -4,6 +4,7 @@ package dotc import core._ import Contexts._ import typer.{TyperPhase, RefChecks} +import cc.CheckCaptures import parsing.Parser import Phases.Phase import transform._ @@ -81,6 +82,8 @@ class Compiler { new SpecializeApplyMethods, // Adds specialized methods to FunctionN new TryCatchPatterns, // Compile cases in try/catch new PatternMatcher) :: // Compile pattern matches + List(new PreRecheck) :: // Preparations for check captures phase, enabled under -Ycc + List(new CheckCaptures) :: // Check captures, enabled under -Ycc List(new ElimOpaque, // Turn opaque into normal aliases new sjs.ExplicitJSClasses, // Make all JS classes explicit (Scala.js only) new ExplicitOuter, // Add accessors to outer classes from nested ones. @@ -103,8 +106,6 @@ class Compiler { new TupleOptimizations, // Optimize generic operations on tuples new LetOverApply, // Lift blocks from receivers of applications new ArrayConstructors) :: // Intercept creation of (non-generic) arrays and intrinsify. - List(new PreRecheck) :: // Preparations for recheck phase, enabled under -Yrecheck - List(new TestRecheck) :: // Test rechecking, enabled under -Yrecheck List(new Erasure) :: // Rewrite types to JVM model, erasing all type parameters, abstract types and refinements. List(new ElimErasedValueType, // Expand erased value types to their underlying implmementation types new PureStats, // Remove pure stats from blocks diff --git a/compiler/src/dotty/tools/dotc/Run.scala b/compiler/src/dotty/tools/dotc/Run.scala index 9f7036f64255..7b32f66c4f6e 100644 --- a/compiler/src/dotty/tools/dotc/Run.scala +++ b/compiler/src/dotty/tools/dotc/Run.scala @@ -20,7 +20,6 @@ import reporting.{Reporter, Suppression, Action} import reporting.Diagnostic import reporting.Diagnostic.Warning import rewrites.Rewrites - import profile.Profiler import printing.XprintMode import typer.ImplicitRunInfo @@ -320,7 +319,7 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint val fusedPhase = ctx.base.fusedContaining(prevPhase) val echoHeader = f"[[syntax trees at end of $fusedPhase%25s]] // ${unit.source}" val tree = if ctx.isAfterTyper then unit.tpdTree else unit.untpdTree - val treeString = tree.show(using ctx.withProperty(XprintMode, Some(()))) + val treeString = fusedPhase.show(tree) last match { case SomePrintedTree(phase, lastTreeString) if lastTreeString == treeString => diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 920871210eee..38030955b776 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -1733,6 +1733,9 @@ object desugar { flatTree(pats1 map (makePatDef(tree, mods, _, rhs))) case ext: ExtMethods => Block(List(ext), Literal(Constant(())).withSpan(ext.span)) + case CapturingTypeTree(refs, parent) => + val annot = New(scalaDot(tpnme.retains), List(refs)) + Annotated(parent, annot) } desugared.withSpan(tree.span) } @@ -1871,6 +1874,8 @@ object desugar { case _ => traverseChildren(tree) } }.traverse(expr) + case CapturingTypeTree(refs, parent) => + collect(parent) case _ => } collect(tree) diff --git a/compiler/src/dotty/tools/dotc/ast/Trees.scala b/compiler/src/dotty/tools/dotc/ast/Trees.scala index 7aa4491c31de..d4b7eff1465b 100644 --- a/compiler/src/dotty/tools/dotc/ast/Trees.scala +++ b/compiler/src/dotty/tools/dotc/ast/Trees.scala @@ -253,16 +253,10 @@ object Trees { /** Tree's denotation can be derived from its type */ abstract class DenotingTree[-T >: Untyped](implicit @constructorOnly src: SourceFile) extends Tree[T] { type ThisTree[-T >: Untyped] <: DenotingTree[T] - override def denot(using Context): Denotation = typeOpt match { + override def denot(using Context): Denotation = typeOpt.stripped match case tpe: NamedType => tpe.denot case tpe: ThisType => tpe.cls.denot - case tpe: AnnotatedType => tpe.stripAnnots match { - case tpe: NamedType => tpe.denot - case tpe: ThisType => tpe.cls.denot - case _ => NoDenotation - } case _ => NoDenotation - } } /** Tree's denot/isType/isTerm properties come from a subtree diff --git a/compiler/src/dotty/tools/dotc/ast/untpd.scala b/compiler/src/dotty/tools/dotc/ast/untpd.scala index 7e00972f354d..b9960cbb4652 100644 --- a/compiler/src/dotty/tools/dotc/ast/untpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/untpd.scala @@ -147,6 +147,9 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { case Floating } + /** {x1, ..., xN} T (only relevant under -Ycc) */ + case class CapturingTypeTree(refs: List[Tree], parent: Tree)(implicit @constructorOnly src: SourceFile) extends TypTree + /** Short-lived usage in typer, does not need copy/transform/fold infrastructure */ case class DependentTypeTree(tp: List[Symbol] => Type)(implicit @constructorOnly src: SourceFile) extends Tree @@ -650,6 +653,10 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { case tree: Number if (digits == tree.digits) && (kind == tree.kind) => tree case _ => finalize(tree, untpd.Number(digits, kind)) } + def CapturingTypeTree(tree: Tree)(refs: List[Tree], parent: Tree)(using Context): Tree = tree match + case tree: CapturingTypeTree if (refs eq tree.refs) && (parent eq tree.parent) => tree + case _ => finalize(tree, untpd.CapturingTypeTree(refs, parent)) + def TypedSplice(tree: Tree)(splice: tpd.Tree)(using Context): ProxyTree = tree match { case tree: TypedSplice if splice `eq` tree.splice => tree case _ => finalize(tree, untpd.TypedSplice(splice)(using ctx)) @@ -715,6 +722,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { tree case MacroTree(expr) => cpy.MacroTree(tree)(transform(expr)) + case CapturingTypeTree(refs, parent) => + cpy.CapturingTypeTree(tree)(transform(refs), transform(parent)) case _ => super.transformMoreCases(tree) } @@ -776,6 +785,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { this(x, splice) case MacroTree(expr) => this(x, expr) + case CapturingTypeTree(refs, parent) => + this(this(x, refs), parent) case _ => super.foldMoreCases(x, tree) } diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureAnnotation.scala b/compiler/src/dotty/tools/dotc/cc/CaptureAnnotation.scala new file mode 100644 index 000000000000..5f73b50a6bbe --- /dev/null +++ b/compiler/src/dotty/tools/dotc/cc/CaptureAnnotation.scala @@ -0,0 +1,63 @@ +package dotty.tools +package dotc +package cc + +import core.* +import Types.*, Symbols.*, Contexts.*, Annotations.* +import ast.Trees.* +import ast.{tpd, untpd} +import Decorators.* +import config.Printers.capt +import printing.Printer +import printing.Texts.Text + + +case class CaptureAnnotation(refs: CaptureSet, boxed: Boolean) extends Annotation: + import CaptureAnnotation.* + import tpd.* + + override def tree(using Context) = + val elems = refs.elems.toList.map { + case cr: TermRef => ref(cr) + case cr: TermParamRef => untpd.Ident(cr.paramName).withType(cr) + case cr: ThisType => This(cr.cls) + } + val arg = repeated(elems, TypeTree(defn.AnyType)) + New(symbol.typeRef, arg :: Nil) + + override def symbol(using Context) = defn.RetainsAnnot + + override def derivedAnnotation(tree: Tree)(using Context): Annotation = + unsupported("derivedAnnotation(Tree)") + + def derivedAnnotation(refs: CaptureSet, boxed: Boolean)(using Context): Annotation = + if (this.refs eq refs) && (this.boxed == boxed) then this + else CaptureAnnotation(refs, boxed) + + override def sameAnnotation(that: Annotation)(using Context): Boolean = that match + case CaptureAnnotation(refs2, boxed2) => refs == refs2 && boxed == boxed2 + case _ => false + + override def mapWith(tp: TypeMap)(using Context) = + val elems = refs.elems.toList + val elems1 = elems.mapConserve(tp) + if elems1 eq elems then this + else if elems1.forall(_.isInstanceOf[CaptureRef]) + then derivedAnnotation(CaptureSet(elems1.asInstanceOf[List[CaptureRef]]*), boxed) + else EmptyAnnotation + + override def refersToParamOf(tl: TermLambda)(using Context): Boolean = + refs.elems.exists { + case TermParamRef(tl1, _) => tl eq tl1 + case _ => false + } + + override def toText(printer: Printer): Text = refs.toText(printer) + + override def hash: Int = (refs.hashCode << 1) | (if boxed then 1 else 0) + + override def eql(that: Annotation) = that match + case that: CaptureAnnotation => (this.refs eq that.refs) && (this.boxed == boxed) + case _ => false + +end CaptureAnnotation diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala new file mode 100644 index 000000000000..09064314b1bf --- /dev/null +++ b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala @@ -0,0 +1,82 @@ +package dotty.tools +package dotc +package cc + +import core.* +import Types.*, Symbols.*, Contexts.*, Annotations.* +import ast.{tpd, untpd} +import Decorators.* +import config.Printers.capt +import util.Property.Key +import tpd.* + +private val Captures: Key[CaptureSet] = Key() +private val IsBoxed: Key[Unit] = Key() + +def retainedElems(tree: Tree)(using Context): List[Tree] = tree match + case Apply(_, Typed(SeqLiteral(elems, _), _) :: Nil) => elems + case _ => Nil + +extension (tree: Tree) + + def toCaptureRef(using Context): CaptureRef = tree.tpe.asInstanceOf[CaptureRef] + + def toCaptureSet(using Context): CaptureSet = + tree.getAttachment(Captures) match + case Some(refs) => refs + case None => + val refs = CaptureSet(retainedElems(tree).map(_.toCaptureRef)*) + .showing(i"toCaptureSet $tree --> $result", capt) + tree.putAttachment(Captures, refs) + refs + + def isBoxedCapturing(using Context): Boolean = + tree.hasAttachment(IsBoxed) + + def setBoxedCapturing()(using Context): Unit = + tree.putAttachment(IsBoxed, ()) + +extension (tp: Type) + + def derivedCapturingType(parent: Type, refs: CaptureSet)(using Context): Type = tp match + case CapturingType(p, r, b) => + if (parent eq p) && (refs eq r) then tp + else CapturingType(parent, refs, b) + + /** If this is type variable instantiated or upper bounded with a capturing type, + * the capture set associated with that type. Extended to and-or types and + * type proxies in the obvious way. If a term has a type with a boxed captureset, + * that captureset counts towards the capture variables of the envirionment. + */ + def boxedCaptured(using Context): CaptureSet = + def getBoxed(tp: Type): CaptureSet = tp match + case CapturingType(_, refs, boxed) => if boxed then refs else CaptureSet.empty + case tp: TypeProxy => getBoxed(tp.superType) + case tp: AndType => getBoxed(tp.tp1) ++ getBoxed(tp.tp2) + case tp: OrType => getBoxed(tp.tp1) ** getBoxed(tp.tp2) + case _ => CaptureSet.empty + getBoxed(tp) + + def isBoxedCapturing(using Context) = !tp.boxedCaptured.isAlwaysEmpty + + def canHaveInferredCapture(using Context): Boolean = tp match + case tp: TypeRef if tp.symbol.isClass => + !tp.symbol.isValueClass && tp.symbol != defn.AnyClass + case _: TypeVar | _: TypeParamRef => + false + case tp: TypeProxy => + tp.superType.canHaveInferredCapture + case tp: AndType => + tp.tp1.canHaveInferredCapture && tp.tp2.canHaveInferredCapture + case tp: OrType => + tp.tp1.canHaveInferredCapture || tp.tp2.canHaveInferredCapture + case _ => + false + + def stripCapturing(using Context): Type = tp.dealiasKeepAnnots match + case CapturingType(parent, _, _) => + parent.stripCapturing + case atd @ AnnotatedType(parent, annot) => + atd.derivedAnnotatedType(parent.stripCapturing, annot) + case _ => + tp diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala b/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala new file mode 100644 index 000000000000..f8ca2f87e3c5 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala @@ -0,0 +1,577 @@ +package dotty.tools +package dotc +package cc + +import core.* +import Types.*, Symbols.*, Flags.*, Contexts.*, Decorators.* +import config.Printers.capt +import Annotations.Annotation +import annotation.threadUnsafe +import annotation.constructorOnly +import annotation.internal.sharable +import reporting.trace +import printing.{Showable, Printer} +import printing.Texts.* +import util.{SimpleIdentitySet, Property} +import util.common.alwaysTrue +import scala.collection.mutable + +/** A class for capture sets. Capture sets can be constants or variables. + * Capture sets support inclusion constraints <:< where <:< is subcapturing. + * They also allow mapping with arbitrary functions from elements to capture sets, + * by supporting a monadic flatMap operation. That is, constraints can be + * of one of the following forms + * + * cs1 <:< cs2 + * cs1 = ∪ {f(x) | x ∈ cs2} + * + * where the `f`s are arbitrary functions from capture references to capture sets. + * We call the resulting constraint system "monadic set constraints". + */ +sealed abstract class CaptureSet extends Showable: + import CaptureSet.* + + /** The elements of this capture set. For capture variables, + * the elements known so far. + */ + def elems: Refs + + /** Is this capture set constant (i.e. not an unsolved capture variable)? + * Solved capture variables count as constant. + */ + def isConst: Boolean + + /** Is this capture set always empty? For capture veraiables, returns + * always false + */ + def isAlwaysEmpty: Boolean + + /** Is this capture set definitely non-empty? */ + final def isNotEmpty: Boolean = !elems.isEmpty + + /** Cast to variable. @pre: @isConst */ + def asVar: Var = + assert(!isConst) + asInstanceOf[Var] + + /** Add new elements to this capture set if allowed. + * @pre `newElems` is not empty and does not overlap with `this.elems`. + * Constant capture sets never allow to add new elements. + * Variables allow it if and only if the new elements can be included + * in all their supersets. + * @param origin The set where the elements come from, or `empty` if not known. + * @return CompareResult.OK if elements were added, or a conflicting + * capture set that prevents addition otherwise. + */ + protected def addNewElems(newElems: Refs, origin: CaptureSet)(using Context, VarState): CompareResult + + /** If this is a variable, add `cs` as a super set */ + protected def addSuper(cs: CaptureSet)(using Context, VarState): CompareResult + + /** If `cs` is a variable, add this capture set as one of its super sets */ + protected def addSub(cs: CaptureSet)(using Context): this.type = + cs.addSuper(this)(using ctx, UnrecordedState) + this + + /** Try to include all references of `elems` that are not yet accounted by this + * capture set. Inclusion is via `addNewElems`. + * @param origin The set where the elements come from, or `empty` if not known. + * @return CompareResult.OK if all unaccounted elements could be added, + * capture set that prevents addition otherwise. + */ + protected final def tryInclude(elems: Refs, origin: CaptureSet)(using Context, VarState): CompareResult = + val unaccounted = elems.filter(!accountsFor(_)) + if unaccounted.isEmpty then CompareResult.OK + else addNewElems(unaccounted, origin) + + protected final def tryInclude(elem: CaptureRef, origin: CaptureSet)(using Context, VarState): CompareResult = + if accountsFor(elem) then CompareResult.OK + else addNewElems(elem.singletonCaptureSet.elems, origin) + + extension (x: CaptureRef) private def subsumes(y: CaptureRef) = + (x eq y) + || y.match + case y: TermRef => y.prefix eq x // ^^^ y.prefix.subsumes(x) ? + case _ => false + + /** {x} <:< this where <:< is subcapturing, but treating all variables + * as frozen. + */ + def accountsFor(x: CaptureRef)(using ctx: Context): Boolean = + reporting.trace(i"$this accountsFor $x, ${x.captureSetOfInfo}?", show = true) { + elems.exists(_.subsumes(x)) + || !x.isRootCapability && x.captureSetOfInfo.subCaptures(this, frozen = true).isOK + } + + /** The subcapturing test */ + final def subCaptures(that: CaptureSet, frozen: Boolean)(using Context): CompareResult = + subCaptures(that)(using ctx, if frozen then FrozenState else VarState()) + + private def subCaptures(that: CaptureSet)(using Context, VarState): CompareResult = + def recur(elems: List[CaptureRef]): CompareResult = elems match + case elem :: elems1 => + var result = that.tryInclude(elem, this) + if !result.isOK && !elem.isRootCapability && summon[VarState] != FrozenState then + result = elem.captureSetOfInfo.subCaptures(that) + if result.isOK then + recur(elems1) + else + varState.abort() + result + case Nil => + addSuper(that) + recur(elems.toList) + .showing(i"subcaptures $this <:< $that = ${result.show}", capt) + + def =:= (that: CaptureSet)(using Context): Boolean = + this.subCaptures(that, frozen = true).isOK + && that.subCaptures(this, frozen = true).isOK + + /** The smallest capture set (via <:<) that is a superset of both + * `this` and `that` + */ + def ++ (that: CaptureSet)(using Context): CaptureSet = + if this.subCaptures(that, frozen = true).isOK then that + else if that.subCaptures(this, frozen = true).isOK then this + else if this.isConst && that.isConst then Const(this.elems ++ that.elems) + else Var(this.elems ++ that.elems).addSub(this).addSub(that) + + /** The smallest superset (via <:<) of this capture set that also contains `ref`. + */ + def + (ref: CaptureRef)(using Context): CaptureSet = + this ++ ref.singletonCaptureSet + + /** The largest capture set (via <:<) that is a subset of both `this` and `that` + */ + def **(that: CaptureSet)(using Context): CaptureSet = + if this.subCaptures(that, frozen = true).isOK then this + else if that.subCaptures(this, frozen = true).isOK then that + else if this.isConst && that.isConst then Const(elems.intersect(that.elems)) + else if that.isConst then Intersected(this.asVar, that) + else Intersected(that.asVar, this) + + def -- (that: CaptureSet.Const)(using Context): CaptureSet = + val elems1 = elems.filter(!that.accountsFor(_)) + if elems1.size == elems.size then this + else if this.isConst then Const(elems1) + else Diff(asVar, that) + + def - (ref: CaptureRef)(using Context): CaptureSet = + this -- ref.singletonCaptureSet + + def filter(p: CaptureRef => Boolean)(using Context): CaptureSet = + if this.isConst then Const(elems.filter(p)) + else Filtered(asVar, p) + + /** capture set obtained by applying `f` to all elements of the current capture set + * and joining the results. If the current capture set is a variable, the same + * transformation is applied to all future additions of new elements. + */ + def map(tm: TypeMap)(using Context): CaptureSet = tm match + case tm: BiTypeMap => + val mappedElems = elems.map(tm.forward) + if isConst then Const(mappedElems) + else BiMapped(asVar, tm, mappedElems) + case _ => + val mapped = mapRefs(elems, tm, tm.variance) + if isConst then mapped + else Mapped(asVar, tm, tm.variance, mapped) + + def substParams(tl: BindingType, to: List[Type])(using Context) = + map(Substituters.SubstParamsMap(tl, to)) + + /** An upper approximation of this capture set. This is the set itself + * except for real (non-mapped, non-filtered) capture set variables, where + * it is the intersection of all upper approximations of known supersets + * of the variable. + * The upper approximation is meaningful only if it is constant. If not, + * `upperApprox` can return an arbitrary capture set variable. + */ + protected def upperApprox(origin: CaptureSet)(using Context): CaptureSet + + protected def propagateSolved()(using Context): Unit = () + + def toRetainsTypeArg(using Context): Type = + assert(isConst) + ((NoType: Type) /: elems) ((tp, ref) => + if tp.exists then OrType(tp, ref, soft = false) else ref) + + def toRegularAnnotation(using Context): Annotation = + Annotation(CaptureAnnotation(this, boxed = false).tree) + + override def toText(printer: Printer): Text = + Str("{") ~ Text(elems.toList.map(printer.toTextCaptureRef), ", ") ~ Str("}") + +object CaptureSet: + type Refs = SimpleIdentitySet[CaptureRef] + type Vars = SimpleIdentitySet[Var] + type Deps = SimpleIdentitySet[CaptureSet] + + /** If set to `true`, capture stack traces that tell us where sets are created */ + private final val debugSets = false + + private val emptySet = SimpleIdentitySet.empty + @sharable private var varId = 0 + + val empty: CaptureSet.Const = Const(emptySet) + + /** The universal capture set `{*}` */ + def universal(using Context): CaptureSet = + defn.captureRoot.termRef.singletonCaptureSet + + /** Used as a recursion brake */ + @sharable private[dotc] val Pending = Const(SimpleIdentitySet.empty) + + def apply(elems: CaptureRef*)(using Context): CaptureSet.Const = + if elems.isEmpty then empty + else Const(SimpleIdentitySet(elems.map(_.normalizedRef)*)) + + def apply(elems: Refs)(using Context): CaptureSet.Const = + if elems.isEmpty then empty else Const(elems) + + class Const private[CaptureSet] (val elems: Refs) extends CaptureSet: + assert(elems != null) + def isConst = true + def isAlwaysEmpty = elems.isEmpty + + def addNewElems(elems: Refs, origin: CaptureSet)(using Context, VarState): CompareResult = + CompareResult.fail(this) + + def addSuper(cs: CaptureSet)(using Context, VarState) = CompareResult.OK + + def upperApprox(origin: CaptureSet)(using Context): CaptureSet = this + + override def toString = elems.toString + end Const + + class Var(initialElems: Refs = emptySet) extends CaptureSet: + val id = + varId += 1 + varId + + private var isSolved: Boolean = false + + var elems: Refs = initialElems + var deps: Deps = emptySet + def isConst = isSolved + def isAlwaysEmpty = false + + private def recordElemsState()(using VarState): Boolean = + varState.getElems(this) match + case None => varState.putElems(this, elems) + case _ => true + + private[CaptureSet] def recordDepsState()(using VarState): Boolean = + varState.getDeps(this) match + case None => varState.putDeps(this, deps) + case _ => true + + def resetElems()(using state: VarState): Unit = + elems = state.elems(this) + + def resetDeps()(using state: VarState): Unit = + deps = state.deps(this) + + def addNewElems(newElems: Refs, origin: CaptureSet)(using Context, VarState): CompareResult = + if !isConst && recordElemsState() then + elems ++= newElems + // assert(id != 2 || elems.size != 2, this) + (CompareResult.OK /: deps) { (r, dep) => + r.andAlso(dep.tryInclude(newElems, this)) + } + else + CompareResult.fail(this) + + def addSuper(cs: CaptureSet)(using Context, VarState): CompareResult = + if (cs eq this) || cs.elems.contains(defn.captureRoot.termRef) || isConst then + CompareResult.OK + else if recordDepsState() then + deps += cs + CompareResult.OK + else + CompareResult.fail(this) + + private var computingApprox = false + + final def upperApprox(origin: CaptureSet)(using Context): CaptureSet = + if computingApprox then universal + else if isConst then this + else + computingApprox = true + try computeApprox(origin).ensuring(_.isConst) + finally computingApprox = false + + protected def computeApprox(origin: CaptureSet)(using Context): CaptureSet = + (universal /: deps) { (acc, sup) => acc ** sup.upperApprox(this) } + + def solve(variance: Int)(using Context): Unit = + if variance < 0 && !isConst then + val approx = upperApprox(empty) + //println(i"solving var $this $approx ${approx.isConst} deps = ${deps.toList}") + if approx.isConst then + val newElems = approx.elems -- elems + if newElems.isEmpty || addNewElems(newElems, empty)(using ctx, VarState()).isOK then + markSolved() + + def markSolved()(using Context): Unit = + isSolved = true + deps.foreach(_.propagateSolved()) + + protected def ids(using Context): String = + val trail = this.match + case dv: DerivedVar => dv.source.ids + case _ => "" + s"$id${getClass.getSimpleName.take(1)}$trail" + + override def toText(printer: Printer): Text = inContext(printer.printerContext) { + for vars <- ctx.property(ShownVars) do vars += this + super.toText(printer) ~ (Str(ids) provided !isConst && ctx.settings.YccDebug.value) + } + + override def toString = s"Var$id$elems" + end Var + + abstract class DerivedVar(initialElems: Refs)(using @constructorOnly ctx: Context) + extends Var(initialElems): + def source: Var + + addSub(source) + + override def propagateSolved()(using Context) = + if source.isConst && !isConst then markSolved() + end DerivedVar + + /** A variable that changes when `source` changes, where all additional new elements are mapped + * using ∪ { f(x) | x <- elems } + */ + class Mapped private[CaptureSet] + (val source: Var, tm: TypeMap, variance: Int, initial: CaptureSet)(using @constructorOnly ctx: Context) + extends DerivedVar(initial.elems): + addSub(initial) + val stack = if debugSets then (new Throwable).getStackTrace().take(20) else null + + private def whereCreated(using Context): String = + if stack == null then "" + else i""" + |Stack trace of variable creation:" + |${stack.mkString("\n")}""" + + override def addNewElems(newElems: Refs, origin: CaptureSet)(using Context, VarState): CompareResult = + val added = + if origin eq source then + mapRefs(newElems, tm, variance) + else + if variance <= 0 && !origin.isConst && (origin ne initial) then + report.warning(i"trying to add elems $newElems from unrecognized source $origin of mapped set $this$whereCreated") + return CompareResult.fail(this) + Const(newElems) + super.addNewElems(added.elems, origin) + .andAlso { + if added.isConst then CompareResult.OK + else if added.asVar.recordDepsState() then { addSub(added); CompareResult.OK } + else CompareResult.fail(this) + } + + override def computeApprox(origin: CaptureSet)(using Context): CaptureSet = + if source eq origin then universal + else source.upperApprox(this).map(tm) + + override def propagateSolved()(using Context) = + if initial.isConst then super.propagateSolved() + + override def toString = s"Mapped$id($source, elems = $elems)" + end Mapped + + class BiMapped private[CaptureSet] + (val source: Var, bimap: BiTypeMap, initialElems: Refs)(using @constructorOnly ctx: Context) + extends DerivedVar(initialElems): + + override def addNewElems(newElems: Refs, origin: CaptureSet)(using Context, VarState): CompareResult = + if origin eq source then + super.addNewElems(newElems.map(bimap.forward), origin) + else + super.addNewElems(newElems, origin) + .andAlso { + source.tryInclude(newElems.map(bimap.backward), this) + .showing(i"propagating new elems $newElems backward from $this to $source", capt) + } + + override def computeApprox(origin: CaptureSet)(using Context): CaptureSet = + val supApprox = super.computeApprox(this) + if source eq origin then supApprox.map(bimap.inverseTypeMap) + else source.upperApprox(this).map(bimap) ** supApprox + + override def toString = s"BiMapped$id($source, elems = $elems)" + end BiMapped + + /** A variable with elements given at any time as { x <- source.elems | p(x) } */ + class Filtered private[CaptureSet] + (val source: Var, p: CaptureRef => Boolean)(using @constructorOnly ctx: Context) + extends DerivedVar(source.elems.filter(p)): + + override def addNewElems(newElems: Refs, origin: CaptureSet)(using Context, VarState): CompareResult = + super.addNewElems(newElems.filter(p), origin) + + override def computeApprox(origin: CaptureSet)(using Context): CaptureSet = + if source eq origin then universal + else source.upperApprox(this).filter(p) + + override def toString = s"${getClass.getSimpleName}$id($source, elems = $elems)" + end Filtered + + /** A variable with elements given at any time as { x <- source.elems | !other.accountsFor(x) } */ + class Diff(source: Var, other: Const)(using Context) + extends Filtered(source, !other.accountsFor(_)) + + /** A variable with elements given at any time as { x <- source.elems | other.accountsFor(x) } */ + class Intersected(source: Var, other: CaptureSet)(using Context) + extends Filtered(source, other.accountsFor(_)): + addSub(other) + + def extrapolateCaptureRef(r: CaptureRef, tm: TypeMap, variance: Int)(using Context): CaptureSet = + val r1 = tm(r) + val upper = r1.captureSet + def isExact = + upper.isAlwaysEmpty || upper.isConst && upper.elems.size == 1 && upper.elems.contains(r1) + if variance > 0 || isExact then upper + else if variance < 0 then CaptureSet.empty + else assert(false, i"trying to add $upper from $r via ${tm.getClass} in a non-variant setting") + + def mapRefs(xs: Refs, f: CaptureRef => CaptureSet)(using Context): CaptureSet = + ((empty: CaptureSet) /: xs)((cs, x) => cs ++ f(x)) + + def mapRefs(xs: Refs, tm: TypeMap, variance: Int)(using Context): CaptureSet = + mapRefs(xs, extrapolateCaptureRef(_, tm, variance)) + + type CompareResult = CompareResult.Type + + /** None = ok, Some(cs) = failure since not a subset of cs */ + object CompareResult: + opaque type Type = CaptureSet + val OK: Type = Const(emptySet) + def fail(cs: CaptureSet): Type = cs + extension (result: Type) + def isOK: Boolean = result eq OK + def blocking: CaptureSet = result + def show: String = if result.isOK then "OK" else result.toString + def andAlso(op: Context ?=> Type)(using Context): Type = if result.isOK then op else result + + class VarState: + private val elemsMap: util.EqHashMap[Var, Refs] = new util.EqHashMap + private val depsMap: util.EqHashMap[Var, Deps] = new util.EqHashMap + + def elems(v: Var): Refs = elemsMap(v) + def getElems(v: Var): Option[Refs] = elemsMap.get(v) + def putElems(v: Var, elems: Refs): Boolean = { elemsMap(v) = elems; true } + + def deps(v: Var): Deps = depsMap(v) + def getDeps(v: Var): Option[Deps] = depsMap.get(v) + def putDeps(v: Var, deps: Deps): Boolean = { depsMap(v) = deps; true } + + def abort(): Unit = + elemsMap.keysIterator.foreach(_.resetElems()(using this)) + depsMap.keysIterator.foreach(_.resetDeps()(using this)) + end VarState + + @sharable + object FrozenState extends VarState: + override def putElems(v: Var, refs: Refs) = false + override def putDeps(v: Var, deps: Deps) = false + override def abort(): Unit = () + + @sharable + object UnrecordedState extends VarState: + override def putElems(v: Var, refs: Refs) = true + override def putDeps(v: Var, deps: Deps) = true + override def abort(): Unit = () + + def varState(using state: VarState): VarState = state + + def ofClass(cinfo: ClassInfo, argTypes: List[Type])(using Context): CaptureSet = + CaptureSet.empty + /* + def captureSetOf(tp: Type): CaptureSet = tp match + case tp: TypeRef if tp.symbol.is(ParamAccessor) => + def mapArg(accs: List[Symbol], tps: List[Type]): CaptureSet = accs match + case acc :: accs1 if tps.nonEmpty => + if acc == tp.symbol then tps.head.captureSet + else mapArg(accs1, tps.tail) + case _ => + empty + mapArg(cinfo.cls.paramAccessors, argTypes) + case _ => + tp.captureSet + val css = + for + parent <- cinfo.parents if parent.classSymbol == defn.RetainingClass + arg <- parent.argInfos + yield captureSetOf(arg) + css.foldLeft(empty)(_ ++ _) + */ + def ofInfo(ref: CaptureRef)(using Context): CaptureSet = ref match + case ref: ThisType => + val declaredCaptures = ref.cls.givenSelfType.captureSet + ref.cls.paramAccessors.foldLeft(declaredCaptures) ((cs, acc) => + cs ++ acc.termRef.captureSetOfInfo) // ^^^ need to also include outer references of inner classes + .showing(i"cc info $ref with ${ref.cls.paramAccessors.map(_.termRef)}%, % = $result", capt) + case ref: TermRef if ref.isRootCapability => ref.singletonCaptureSet + case _ => ofType(ref.underlying) + + def ofType(tp: Type)(using Context): CaptureSet = + def recur(tp: Type): CaptureSet = tp.dealias match + case tp: TermRef => + tp.captureSet + case tp: TermParamRef => + tp.captureSet + case _: TypeRef | _: TypeParamRef => + empty + case CapturingType(parent, refs, _) => + recur(parent) ++ refs + case AppliedType(tycon, args) => + val cs = recur(tycon) + tycon.typeParams match + case tparams @ (LambdaParam(tl, _) :: _) => cs.substParams(tl, args) + case _ => cs + case tp: TypeProxy => + recur(tp.underlying) + case AndType(tp1, tp2) => + recur(tp1) ** recur(tp2) + case OrType(tp1, tp2) => + recur(tp1) ++ recur(tp2) + case tp: ClassInfo => + ofClass(tp, Nil) + case _ => + empty + recur(tp) + .showing(i"capture set of $tp = $result", capt) + + private val ShownVars: Property.Key[mutable.Set[Var]] = Property.Key() + + def withCaptureSetsExplained[T](op: Context ?=> T)(using ctx: Context): T = + if ctx.settings.YccDebug.value then + val shownVars = mutable.Set[Var]() + inContext(ctx.withProperty(ShownVars, Some(shownVars))) { + try op + finally + val reachable = mutable.Set[Var]() + val todo = mutable.Queue[Var]() ++= shownVars + def incl(cv: Var): Unit = + if !reachable.contains(cv) then todo += cv + while todo.nonEmpty do + val cv = todo.dequeue() + if !reachable.contains(cv) then + reachable += cv + cv.deps.foreach { + case cv: Var => incl(cv) + case _ => + } + cv match + case cv: DerivedVar => incl(cv.source) + case _ => + val allVars = reachable.toArray.sortBy(_.id) + println(i"Capture set dependencies:") + for cv <- allVars do + println(i" ${cv.show.padTo(20, ' ')} :: ${cv.deps.toList}%, %") + } + else op +end CaptureSet diff --git a/compiler/src/dotty/tools/dotc/cc/CapturingType.scala b/compiler/src/dotty/tools/dotc/cc/CapturingType.scala new file mode 100644 index 000000000000..2eeb1ff41b72 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/cc/CapturingType.scala @@ -0,0 +1,21 @@ +package dotty.tools +package dotc +package cc + +import core.* +import Types.*, Symbols.*, Contexts.* + +object CapturingType: + + def apply(parent: Type, refs: CaptureSet, boxed: Boolean)(using Context): Type = + if refs.isAlwaysEmpty then parent + else AnnotatedType(parent, CaptureAnnotation(refs, boxed)) + + def unapply(tp: AnnotatedType)(using Context): Option[(Type, CaptureSet, Boolean)] = + if ctx.phase == Phases.checkCapturesPhase && tp.annot.symbol == defn.RetainsAnnot then + tp.annot match + case ann: CaptureAnnotation => Some((tp.parent, ann.refs, ann.boxed)) + case ann => Some((tp.parent, ann.tree.toCaptureSet, ann.tree.isBoxedCapturing)) + else None + +end CapturingType diff --git a/compiler/src/dotty/tools/dotc/config/Config.scala b/compiler/src/dotty/tools/dotc/config/Config.scala index ac1708378e73..a54987b23ecc 100644 --- a/compiler/src/dotty/tools/dotc/config/Config.scala +++ b/compiler/src/dotty/tools/dotc/config/Config.scala @@ -227,4 +227,9 @@ object Config { * reduces the number of allocated denotations by ~50%. */ inline val reuseSymDenotations = true + + /** If true, print capturing types in the form `{c} T`. + * If false, print them in the form `T @retains(c)`. + */ + inline val printCaptureSetsAsPrefix = true } diff --git a/compiler/src/dotty/tools/dotc/config/Printers.scala b/compiler/src/dotty/tools/dotc/config/Printers.scala index b71e1e7f188a..d20d482b062e 100644 --- a/compiler/src/dotty/tools/dotc/config/Printers.scala +++ b/compiler/src/dotty/tools/dotc/config/Printers.scala @@ -12,6 +12,7 @@ object Printers { val default = new Printer + val capt = noPrinter val constr = noPrinter val core = noPrinter val checks = noPrinter diff --git a/compiler/src/dotty/tools/dotc/config/ScalaSettings.scala b/compiler/src/dotty/tools/dotc/config/ScalaSettings.scala index 627c027bfd34..4499e090a212 100644 --- a/compiler/src/dotty/tools/dotc/config/ScalaSettings.scala +++ b/compiler/src/dotty/tools/dotc/config/ScalaSettings.scala @@ -312,6 +312,8 @@ private sealed trait YSettings: val YrequireTargetName: Setting[Boolean] = BooleanSetting("-Yrequire-targetName", "Warn if an operator is defined without a @targetName annotation") val YscalaRelease: Setting[String] = ChoiceSetting("-Yscala-release", "release", "Emit TASTy files that can be consumed by specified version of the compiler. The compilation will fail if for any reason valid TASTy cannot be produced (e.g. the code contains references to some parts of the standard library API that are missing in the older stdlib or uses language features unexpressible in the older version of TASTy format)", ScalaSettings.supportedScalaReleaseVersions, "", aliases = List("--Yscala-release")) val Yrecheck: Setting[Boolean] = BooleanSetting("-Yrecheck", "Run type rechecks (test only)") + val Ycc: Setting[Boolean] = BooleanSetting("-Ycc", "Check captured references") + val YccDebug: Setting[Boolean] = BooleanSetting("-Ycc-debug", "Debug info for captured references") /** Area-specific debug output */ val YexplainLowlevel: Setting[Boolean] = BooleanSetting("-Yexplain-lowlevel", "When explaining type errors, show types at a lower level.") diff --git a/compiler/src/dotty/tools/dotc/core/Annotations.scala b/compiler/src/dotty/tools/dotc/core/Annotations.scala index b8d62210ce26..d0172c82972c 100644 --- a/compiler/src/dotty/tools/dotc/core/Annotations.scala +++ b/compiler/src/dotty/tools/dotc/core/Annotations.scala @@ -48,7 +48,7 @@ object Annotations { /** The tree evaluation has finished. */ def isEvaluated: Boolean = true - /** Normally, type map over all tree nodes of this annotation, but can + /** Normally, applies a type map to all tree nodes of this annotation, but can * be overridden. Returns EmptyAnnotation if type type map produces a range * type, since ranges cannot be types of trees. */ @@ -86,6 +86,10 @@ object Annotations { def sameAnnotation(that: Annotation)(using Context): Boolean = symbol == that.symbol && tree.sameTree(that.tree) + + /** Operations for hash-consing, can be overridden */ + def hash: Int = System.identityHashCode(this) + def eql(that: Annotation) = this eq that } case class ConcreteAnnotation(t: Tree) extends Annotation: diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 794119cd7a79..689a81ab7a32 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -14,6 +14,7 @@ import typer.ImportInfo.RootRef import Comments.CommentsContext import Comments.Comment import util.Spans.NoSpan +import cc.{CapturingType, CaptureSet} import scala.annotation.tailrec @@ -146,11 +147,13 @@ class Definitions { private def enterMethod(cls: ClassSymbol, name: TermName, info: Type, flags: FlagSet = EmptyFlags): TermSymbol = newMethod(cls, name, info, flags).entered - private def enterAliasType(name: TypeName, tpe: Type, flags: FlagSet = EmptyFlags): TypeSymbol = { - val sym = newPermanentSymbol(ScalaPackageClass, name, flags, TypeAlias(tpe)) + private def enterPermanentSymbol(name: Name, info: Type, flags: FlagSet = EmptyFlags): Symbol = + val sym = newPermanentSymbol(ScalaPackageClass, name, flags, info) ScalaPackageClass.currentPackageDecls.enter(sym) sym - } + + private def enterAliasType(name: TypeName, tpe: Type, flags: FlagSet = EmptyFlags): TypeSymbol = + enterPermanentSymbol(name, TypeAlias(tpe), flags).asType private def enterBinaryAlias(name: TypeName, op: (Type, Type) => Type): TypeSymbol = enterAliasType(name, @@ -446,6 +449,7 @@ class Definitions { @tu lazy val andType: TypeSymbol = enterBinaryAlias(tpnme.AND, AndType(_, _)) @tu lazy val orType: TypeSymbol = enterBinaryAlias(tpnme.OR, OrType(_, _, soft = false)) + @tu lazy val captureRoot: TermSymbol = enterPermanentSymbol(nme.CAPTURE_ROOT, AnyType).asTerm /** Method representing a throw */ @tu lazy val throwMethod: TermSymbol = enterMethod(OpsPackageClass, nme.THROWkw, @@ -943,6 +947,8 @@ class Definitions { @tu lazy val TargetNameAnnot: ClassSymbol = requiredClass("scala.annotation.targetName") @tu lazy val VarargsAnnot: ClassSymbol = requiredClass("scala.annotation.varargs") @tu lazy val SinceAnnot: ClassSymbol = requiredClass("scala.annotation.since") + @tu lazy val RetainsAnnot: ClassSymbol = requiredClass("scala.retains") + @tu lazy val AbilityAnnot: ClassSymbol = requiredClass("scala.annotation.ability") @tu lazy val JavaRepeatableAnnot: ClassSymbol = requiredClass("java.lang.annotation.Repeatable") @@ -1535,6 +1541,9 @@ class Definitions { def isFunctionType(tp: Type)(using Context): Boolean = isNonRefinedFunction(tp.dropDependentRefinement) + def isFunctionOrPolyType(tp: RefinedType)(using Context): Boolean = + isFunctionType(tp) || (tp.parent.typeSymbol eq defn.PolyFunctionClass) + // Specialized type parameters defined for scala.Function{0,1,2}. @tu lazy val Function1SpecializedParamTypes: collection.Set[TypeRef] = Set(IntType, LongType, FloatType, DoubleType) @@ -1835,7 +1844,7 @@ class Definitions { this.initCtx = ctx if (!isInitialized) { // force initialization of every symbol that is synthesized or hijacked by the compiler - val forced = syntheticCoreClasses ++ syntheticCoreMethods ++ ScalaValueClasses() :+ JavaEnumClass + val forced = syntheticCoreClasses ++ syntheticCoreMethods ++ ScalaValueClasses() ++ List(JavaEnumClass, captureRoot) isInitialized = true } diff --git a/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala b/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala index 1f83224cc3e7..85598e79ad9b 100644 --- a/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala @@ -12,6 +12,7 @@ import config.Printers.constr import reflect.ClassTag import annotation.tailrec import annotation.internal.sharable +import cc.{CapturingType, derivedCapturingType} object OrderingConstraint { @@ -330,6 +331,9 @@ class OrderingConstraint(private val boundsMap: ParamBounds, case tp: TypeVar => val underlying1 = recur(tp.underlying, fromBelow) if underlying1 ne tp.underlying then underlying1 else tp + case CapturingType(parent, refs, _) => + val parent1 = recur(parent, fromBelow) + if parent1 ne parent then tp.derivedCapturingType(parent1, refs) else tp case tp: AnnotatedType => val parent1 = recur(tp.parent, fromBelow) if parent1 ne tp.parent then tp.derivedAnnotatedType(parent1, tp.annot) else tp diff --git a/compiler/src/dotty/tools/dotc/core/Phases.scala b/compiler/src/dotty/tools/dotc/core/Phases.scala index 63a70ad73a83..3eae27b3b4ba 100644 --- a/compiler/src/dotty/tools/dotc/core/Phases.scala +++ b/compiler/src/dotty/tools/dotc/core/Phases.scala @@ -13,10 +13,12 @@ import scala.collection.mutable.ListBuffer import dotty.tools.dotc.transform.MegaPhase._ import dotty.tools.dotc.transform._ import Periods._ -import parsing.{ Parser} +import parsing.Parser +import printing.XprintMode import typer.{TyperPhase, RefChecks} +import cc.CheckCaptures import typer.ImportInfo.withRootImports -import ast.tpd +import ast.{tpd, untpd} import scala.annotation.internal.sharable import scala.util.control.NonFatal @@ -217,6 +219,7 @@ object Phases { private var myCountOuterAccessesPhase: Phase = _ private var myFlattenPhase: Phase = _ private var myGenBCodePhase: Phase = _ + private var myCheckCapturesPhase: Phase = _ final def parserPhase: Phase = myParserPhase final def typerPhase: Phase = myTyperPhase @@ -240,6 +243,7 @@ object Phases { final def countOuterAccessesPhase = myCountOuterAccessesPhase final def flattenPhase: Phase = myFlattenPhase final def genBCodePhase: Phase = myGenBCodePhase + final def checkCapturesPhase: Phase = myCheckCapturesPhase private def setSpecificPhases() = { def phaseOfClass(pclass: Class[?]) = phases.find(pclass.isInstance).getOrElse(NoPhase) @@ -265,7 +269,8 @@ object Phases { myFlattenPhase = phaseOfClass(classOf[Flatten]) myExplicitOuterPhase = phaseOfClass(classOf[ExplicitOuter]) myGettersPhase = phaseOfClass(classOf[Getters]) - myGenBCodePhase = phaseOfClass(classOf[GenBCode]) + myGenBCodePhase = phaseOfClass(classOf[GenBCode]) + myCheckCapturesPhase = phaseOfClass(classOf[CheckCaptures]) } final def isAfterTyper(phase: Phase): Boolean = phase.id > typerPhase.id @@ -315,6 +320,10 @@ object Phases { unitCtx.compilationUnit } + /** Convert a compilation unit's tree to a string; can be overridden */ + def show(tree: untpd.Tree)(using Context): String = + tree.show(using ctx.withProperty(XprintMode, Some(()))) + def description: String = phaseName /** Output should be checkable by TreeChecker */ @@ -442,6 +451,7 @@ object Phases { def lambdaLiftPhase(using Context): Phase = ctx.base.lambdaLiftPhase def flattenPhase(using Context): Phase = ctx.base.flattenPhase def genBCodePhase(using Context): Phase = ctx.base.genBCodePhase + def checkCapturesPhase(using Context): Phase = ctx.base.checkCapturesPhase def unfusedPhases(using Context): Array[Phase] = ctx.base.phases diff --git a/compiler/src/dotty/tools/dotc/core/StdNames.scala b/compiler/src/dotty/tools/dotc/core/StdNames.scala index f8c70176482c..8ab97925ecaa 100644 --- a/compiler/src/dotty/tools/dotc/core/StdNames.scala +++ b/compiler/src/dotty/tools/dotc/core/StdNames.scala @@ -283,6 +283,7 @@ object StdNames { // Compiler-internal val ANYname: N = "" + val CAPTURE_ROOT: N = "*" val COMPANION: N = "" val CONSTRUCTOR: N = "" val STATIC_CONSTRUCTOR: N = "" @@ -370,6 +371,7 @@ object StdNames { val AppliedTypeTree: N = "AppliedTypeTree" val ArrayAnnotArg: N = "ArrayAnnotArg" val CAP: N = "CAP" + val ClassManifestFactory: N = "ClassManifestFactory" val Constant: N = "Constant" val ConstantType: N = "ConstantType" val Eql: N = "Eql" @@ -446,7 +448,6 @@ object StdNames { val canEqual_ : N = "canEqual" val canEqualAny : N = "canEqualAny" val checkInitialized: N = "checkInitialized" - val ClassManifestFactory: N = "ClassManifestFactory" val classOf: N = "classOf" val classType: N = "classType" val clone_ : N = "clone" @@ -578,6 +579,7 @@ object StdNames { val reflectiveSelectable: N = "reflectiveSelectable" val reify : N = "reify" val releaseFence : N = "releaseFence" + val retains: N = "retains" val rootMirror : N = "rootMirror" val run: N = "run" val runOrElse: N = "runOrElse" diff --git a/compiler/src/dotty/tools/dotc/core/Substituters.scala b/compiler/src/dotty/tools/dotc/core/Substituters.scala index f00edcb189c6..b277f2cd8619 100644 --- a/compiler/src/dotty/tools/dotc/core/Substituters.scala +++ b/compiler/src/dotty/tools/dotc/core/Substituters.scala @@ -161,8 +161,9 @@ object Substituters: .mapOver(tp) } - final class SubstBindingMap(from: BindingType, to: BindingType)(using Context) extends DeepTypeMap { + final class SubstBindingMap(from: BindingType, to: BindingType)(using Context) extends DeepTypeMap, BiTypeMap { def apply(tp: Type): Type = subst(tp, from, to, this)(using mapCtx) + def inverse(tp: Type): Type = tp.subst(to, from) } final class Subst1Map(from: Symbol, to: Type)(using Context) extends DeepTypeMap { @@ -177,8 +178,9 @@ object Substituters: def apply(tp: Type): Type = subst(tp, from, to, this)(using mapCtx) } - final class SubstSymMap(from: List[Symbol], to: List[Symbol])(using Context) extends DeepTypeMap { + final class SubstSymMap(from: List[Symbol], to: List[Symbol])(using Context) extends DeepTypeMap, BiTypeMap { def apply(tp: Type): Type = substSym(tp, from, to, this)(using mapCtx) + def inverse(tp: Type) = tp.substSym(to, from) } final class SubstThisMap(from: ClassSymbol, to: Type)(using Context) extends DeepTypeMap { diff --git a/compiler/src/dotty/tools/dotc/core/SymDenotations.scala b/compiler/src/dotty/tools/dotc/core/SymDenotations.scala index 2f7e3debfa6f..83299918ad61 100644 --- a/compiler/src/dotty/tools/dotc/core/SymDenotations.scala +++ b/compiler/src/dotty/tools/dotc/core/SymDenotations.scala @@ -24,6 +24,7 @@ import config.Config import reporting._ import collection.mutable import transform.TypeUtils._ +import cc.{CapturingType, derivedCapturingType} import scala.annotation.internal.sharable @@ -225,6 +226,8 @@ object SymDenotations { ensureCompleted(); myAnnotations } + final def annotationsUNSAFE(using Context): List[Annotation] = myAnnotations + /** Update the annotations of this denotation */ final def annotations_=(annots: List[Annotation]): Unit = myAnnotations = annots @@ -1507,8 +1510,7 @@ object SymDenotations { case tp: ExprType => hasSkolems(tp.resType) case tp: AppliedType => hasSkolems(tp.tycon) || tp.args.exists(hasSkolems) case tp: LambdaType => tp.paramInfos.exists(hasSkolems) || hasSkolems(tp.resType) - case tp: AndType => hasSkolems(tp.tp1) || hasSkolems(tp.tp2) - case tp: OrType => hasSkolems(tp.tp1) || hasSkolems(tp.tp2) + case tp: AndOrType => hasSkolems(tp.tp1) || hasSkolems(tp.tp2) case tp: AnnotatedType => hasSkolems(tp.parent) case _ => false } @@ -2164,6 +2166,9 @@ object SymDenotations { case tp: TypeParamRef => // uncachable, since baseType depends on context bounds recur(TypeComparer.bounds(tp).hi) + case CapturingType(parent, refs, _) => + tp.derivedCapturingType(recur(parent), refs) + case tp: TypeProxy => def computeTypeProxy = { val superTp = tp.superType diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 4a0dc5f03410..0d6f137f3dd9 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -24,6 +24,7 @@ import typer.Applications.productSelectorTypes import reporting.trace import NullOpsDecorator._ import annotation.constructorOnly +import cc.{CapturingType, derivedCapturingType, CaptureSet, stripCapturing} /** Provides methods to compare types. */ @@ -319,6 +320,8 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling compareWild case tp2: LazyRef => isBottom(tp1) || !tp2.evaluating && recur(tp1, tp2.ref) + case CapturingType(_, _, _) => + secondTry case tp2: AnnotatedType if !tp2.isRefining => recur(tp1, tp2.parent) case tp2: ThisType => @@ -438,8 +441,6 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling // See i859.scala for an example where we hit this case. tp2.isRef(AnyClass, skipRefined = false) || !tp1.evaluating && recur(tp1.ref, tp2) - case tp1: AnnotatedType if !tp1.isRefining => - recur(tp1.parent, tp2) case AndType(tp11, tp12) => if (tp11.stripTypeVar eq tp12.stripTypeVar) recur(tp11, tp2) else thirdTry @@ -483,7 +484,14 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling // and then need to check that they are indeed supertypes of the original types // under -Ycheck. Test case is i7965.scala. - case tp1: MatchType => + case CapturingType(parent1, refs1, _) => + if subCaptures(refs1, tp2.captureSet, frozenConstraint).isOK then + recur(parent1, tp2) + else + thirdTry + case tp1: AnnotatedType if !tp1.isRefining => + recur(tp1.parent, tp2) + case tp1: MatchType => val reduced = tp1.reduced if (reduced.exists) recur(reduced, tp2) else thirdTry case _: FlexType => @@ -521,8 +529,8 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling // Note: We would like to replace this by `if (tp1.hasHigherKind)` // but right now we cannot since some parts of the standard library rely on the // idiom that e.g. `List <: Any`. We have to bootstrap without scalac first. - if (cls2 eq AnyClass) return true - if (cls2 == defn.SingletonClass && tp1.isStable) return true + if cls2 eq AnyClass then return true + if cls2 == defn.SingletonClass && tp1.isStable then return true return tryBaseType(cls2) } else if (cls2.is(JavaDefined)) { @@ -591,6 +599,28 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling isSubRefinements(tp1w.asInstanceOf[RefinedType], tp2, skipped2) && recur(tp1, skipped2) + def isSubInfo(info1: Type, info2: Type): Boolean = (info1, info2) match + case (info1: PolyType, info2: PolyType) => + sameLength(info1.paramNames, info2.paramNames) + && isSubInfo(info1.resultType, info2.resultType.subst(info2, info1)) + case (info1: MethodType, info2: MethodType) => + matchingMethodParams(info1, info2, precise = false) + && isSubInfo(info1.resultType, info2.resultType.subst(info2, info1)) + case _ => + isSubType(info1, info2) + + if ctx.phase == Phases.checkCapturesPhase then + if defn.isFunctionType(tp2) then + tp1.widenDealias match + case tp1: RefinedType => + return isSubInfo(tp1.refinedInfo, tp2.refinedInfo) + case _ => + else if tp2.parent.typeSymbol == defn.PolyFunctionClass then + tp1.member(nme.apply).info match + case info1: PolyType => + return isSubInfo(info1, tp2.refinedInfo) + case _ => + compareRefined case tp2: RecType => def compareRec = tp1.safeDealias match { @@ -721,13 +751,17 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling def compareTypeBounds = tp1 match { case tp1 @ TypeBounds(lo1, hi1) => ((lo2 eq NothingType) || isSubType(lo2, lo1)) && - ((hi2 eq AnyType) && !hi1.isLambdaSub || (hi2 eq AnyKindType) || isSubType(hi1, hi2)) + ((hi2 eq AnyType) && !hi1.isLambdaSub + || (hi2 eq AnyKindType) + || isSubType(hi1, hi2)) case tp1: ClassInfo => tp2 contains tp1 case _ => false } compareTypeBounds + case CapturingType(parent2, _, _) => + recur(tp1, parent2) || fourthTry case tp2: AnnotatedType if tp2.isRefining => (tp1.derivesAnnotWith(tp2.annot.sameAnnotation) || tp1.isBottomType) && recur(tp1, tp2.parent) @@ -774,6 +808,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling case tp: AppliedType => isNullable(tp.tycon) case AndType(tp1, tp2) => isNullable(tp1) && isNullable(tp2) case OrType(tp1, tp2) => isNullable(tp1) || isNullable(tp2) + case CapturingType(tp1, _, _) => isNullable(tp1) case _ => false } val sym1 = tp1.symbol @@ -792,7 +827,15 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling case _ => false } case _ => false - comparePaths || isSubType(tp1.underlying.widenExpr, tp2, approx.addLow) + comparePaths || { + var tp1w = tp1.underlying.widenExpr + tp1 match + case tp1: CaptureRef if tp1.isTracked => + val stripped = tp1w.stripCapturing + tp1w = CapturingType(stripped, tp1.singletonCaptureSet, boxed = false) + case _ => + isSubType(tp1w, tp2, approx.addLow) + } case tp1: RefinedType => isNewSubType(tp1.parent) case tp1: RecType => @@ -1763,69 +1806,68 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling protected def hasMatchingMember(name: Name, tp1: Type, tp2: RefinedType): Boolean = trace(i"hasMatchingMember($tp1 . $name :? ${tp2.refinedInfo}), mbr: ${tp1.member(name).info}", subtyping) { - def qualifies(m: SingleDenotation): Boolean = - // If the member is an abstract type and the prefix is a path, compare the member itself - // instead of its bounds. This case is needed situations like: - // - // class C { type T } - // val foo: C - // foo.type <: C { type T {= , <: , >:} foo.T } - // - // or like: - // - // class C[T] - // C[?] <: C[TV] - // - // where TV is a type variable. See i2397.scala for an example of the latter. - def matchAbstractTypeMember(info1: Type): Boolean = info1 match { - case TypeBounds(lo, hi) if lo ne hi => - tp2.refinedInfo match { - case rinfo2: TypeBounds if tp1.isStable => - val ref1 = tp1.widenExpr.select(name) - isSubType(rinfo2.lo, ref1) && isSubType(ref1, rinfo2.hi) - case _ => - false - } - case _ => false - } + // If the member is an abstract type and the prefix is a path, compare the member itself + // instead of its bounds. This case is needed situations like: + // + // class C { type T } + // val foo: C + // foo.type <: C { type T {= , <: , >:} foo.T } + // + // or like: + // + // class C[T] + // C[?] <: C[TV] + // + // where TV is a type variable. See i2397.scala for an example of the latter. + def matchAbstractTypeMember(info1: Type): Boolean = info1 match { + case TypeBounds(lo, hi) if lo ne hi => + tp2.refinedInfo match { + case rinfo2: TypeBounds if tp1.isStable => + val ref1 = tp1.widenExpr.select(name) + isSubType(rinfo2.lo, ref1) && isSubType(ref1, rinfo2.hi) + case _ => + false + } + case _ => false + } - // An additional check for type member matching: If the refinement of the - // supertype `tp2` does not refer to a member symbol defined in the parent of `tp2`. - // then the symbol referred to in the subtype must have a signature that coincides - // in its parameters with the refinement's signature. The reason for the check - // is that if the refinement does not refer to a member symbol, we will have to - // resort to reflection to invoke the member. And Java reflection needs to know exact - // erased parameter types. See neg/i12211.scala. Other reflection algorithms could - // conceivably dispatch without knowning precise parameter signatures. One can signal - // this by inheriting from the `scala.reflect.SignatureCanBeImprecise` marker trait, - // in which case the signature test is elided. - def sigsOK(symInfo: Type, info2: Type) = - tp2.underlyingClassRef(refinementOK = true).member(name).exists - || tp2.derivesFrom(defn.WithoutPreciseParameterTypesClass) - || symInfo.isInstanceOf[MethodType] - && symInfo.signature.consistentParams(info2.signature) - - // A relaxed version of isSubType, which compares method types - // under the standard arrow rule which is contravarient in the parameter types, - // but under the condition that signatures might have to match (see sigsOK) - // This relaxed version is needed to correctly compare dependent function types. - // See pos/i12211.scala. - def isSubInfo(info1: Type, info2: Type, symInfo: Type): Boolean = - info2 match - case info2: MethodType => - info1 match - case info1: MethodType => - val symInfo1 = symInfo.stripPoly - matchingMethodParams(info1, info2, precise = false) - && isSubInfo(info1.resultType, info2.resultType.subst(info2, info1), symInfo1.resultType) - && sigsOK(symInfo1, info2) - case _ => isSubType(info1, info2) - case _ => isSubType(info1, info2) + // An additional check for type member matching: If the refinement of the + // supertype `tp2` does not refer to a member symbol defined in the parent of `tp2`. + // then the symbol referred to in the subtype must have a signature that coincides + // in its parameters with the refinement's signature. The reason for the check + // is that if the refinement does not refer to a member symbol, we will have to + // resort to reflection to invoke the member. And Java reflection needs to know exact + // erased parameter types. See neg/i12211.scala. Other reflection algorithms could + // conceivably dispatch without knowning precise parameter signatures. One can signal + // this by inheriting from the `scala.reflect.SignatureCanBeImprecise` marker trait, + // in which case the signature test is elided. + def sigsOK(symInfo: Type, info2: Type) = + tp2.underlyingClassRef(refinementOK = true).member(name).exists + || tp2.derivesFrom(defn.WithoutPreciseParameterTypesClass) + || symInfo.isInstanceOf[MethodType] + && symInfo.signature.consistentParams(info2.signature) + + // A relaxed version of isSubType, which compares method types + // under the standard arrow rule which is contravarient in the parameter types, + // but under the condition that signatures might have to match (see sigsOK) + // This relaxed version is needed to correctly compare dependent function types. + // See pos/i12211.scala. + def isSubInfo(info1: Type, info2: Type, symInfo: Type): Boolean = + info2 match + case info2: MethodType => + info1 match + case info1: MethodType => + val symInfo1 = symInfo.stripPoly + matchingMethodParams(info1, info2, precise = false) + && isSubInfo(info1.resultType, info2.resultType.subst(info2, info1), symInfo1.resultType) + && sigsOK(symInfo1, info2) + case _ => isSubType(info1, info2) + case _ => isSubType(info1, info2) + def qualifies(m: SingleDenotation): Boolean = val info1 = m.info.widenExpr isSubInfo(info1, tp2.refinedInfo.widenExpr, m.symbol.info.orElse(info1)) || matchAbstractTypeMember(m.info) - end qualifies tp1.member(name) match // inlined hasAltWith for performance case mbr: SingleDenotation => qualifies(mbr) @@ -1950,8 +1992,12 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling case formal2 :: rest2 => val formal2a = if (tp2.isParamDependent) formal2.subst(tp2, tp1) else formal2 val paramsMatch = - if precise then isSameTypeWhenFrozen(formal1, formal2a) - else isSubTypeWhenFrozen(formal2a, formal1) + if precise then + isSameTypeWhenFrozen(formal1, formal2a) + else if ctx.phase == Phases.checkCapturesPhase then + isSubType(formal2a, formal1) + else + isSubTypeWhenFrozen(formal2a, formal1) paramsMatch && loop(rest1, rest2) case nil => false @@ -2354,6 +2400,11 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling } case tp1: TypeVar if tp1.isInstantiated => tp1.underlying & tp2 + case CapturingType(parent1, refs1, _) => + if subCaptures(tp2.captureSet, refs1, frozenConstraint).isOK then + parent1 & tp2 + else + tp1.derivedCapturingType(parent1 & tp2, refs1) case tp1: AnnotatedType if !tp1.isRefining => tp1.underlying & tp2 case _ => @@ -2416,6 +2467,9 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling false } + protected def subCaptures(refs1: CaptureSet, refs2: CaptureSet, frozen: Boolean)(using Context): CaptureSet.CompareResult.Type = + refs1.subCaptures(refs2, frozen) + // ----------- Diagnostics -------------------------------------------------- /** A hook for showing subtype traces. Overridden in ExplainingTypeComparer */ @@ -2681,6 +2735,7 @@ object TypeComparer { else res match case ClassInfo(_, cls, _, _, _) => cls.showLocated case bounds: TypeBounds => i"type bounds [$bounds]" + case CaptureSet.CompareResult.OK => "OK" case res: printing.Showable => res.show case _ => String.valueOf(res) @@ -3006,5 +3061,10 @@ class ExplainingTypeComparer(initctx: Context) extends TypeComparer(initctx) { super.addConstraint(param, bound, fromBelow) } + override def subCaptures(refs1: CaptureSet, refs2: CaptureSet, frozen: Boolean)(using Context): CaptureSet.CompareResult.Type = + traceIndented(i"subcaptures $refs1 <:< $refs2 ${if frozen then "frozen" else ""}") { + super.subCaptures(refs1, refs2, frozen) + } + def lastTrace(header: String): String = header + { try b.toString finally b.clear() } } diff --git a/compiler/src/dotty/tools/dotc/core/TypeErrors.scala b/compiler/src/dotty/tools/dotc/core/TypeErrors.scala index c9ca98f65f5e..9067d0c87142 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeErrors.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeErrors.scala @@ -73,6 +73,7 @@ class RecursionOverflow(val op: String, details: => String, val previous: Throwa s"""Recursion limit exceeded. |Maybe there is an illegal cyclic reference? |If that's not the case, you could also try to increase the stacksize using the -Xss JVM option. + |For the unprocessed stack trace, compile with -Yno-decode-stacktraces. |A recurring operation is (inner to outer): |${opsString(mostCommon)}""".stripMargin } diff --git a/compiler/src/dotty/tools/dotc/core/TypeOps.scala b/compiler/src/dotty/tools/dotc/core/TypeOps.scala index 2d5c2a6da88a..abe2123cb609 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeOps.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeOps.scala @@ -19,6 +19,8 @@ import typer.ForceDegree import typer.Inferencing._ import typer.IfBottom import reporting.TestingReporter +import cc.{CapturingType, derivedCapturingType, CaptureSet} +import CaptureSet.CompareResult import scala.annotation.internal.sharable import scala.annotation.threadUnsafe @@ -164,6 +166,12 @@ object TypeOps: // with Nulls (which have no base classes). Under -Yexplicit-nulls, we take // corrective steps, so no widening is wanted. simplify(l, theMap) | simplify(r, theMap) + case CapturingType(parent, refs, _) => + if !ctx.mode.is(Mode.Type) + && refs.subCaptures(parent.captureSet, frozen = true).isOK then + simplify(parent, theMap) + else + mapOver case tp @ AnnotatedType(parent, annot) => val parent1 = simplify(parent, theMap) if annot.symbol == defn.UncheckedVarianceAnnot @@ -273,15 +281,23 @@ object TypeOps: case _ => false } - // Step 1: Get RecTypes and ErrorTypes out of the way, + // Step 1: Get RecTypes and ErrorTypes and CapturingTypes out of the way, tp1 match { - case tp1: RecType => return tp1.rebind(approximateOr(tp1.parent, tp2)) - case err: ErrorType => return err + case tp1: RecType => + return tp1.rebind(approximateOr(tp1.parent, tp2)) + case CapturingType(parent1, refs1, _) => + return tp1.derivedCapturingType(approximateOr(parent1, tp2), refs1) + case err: ErrorType => + return err case _ => } tp2 match { - case tp2: RecType => return tp2.rebind(approximateOr(tp1, tp2.parent)) - case err: ErrorType => return err + case tp2: RecType => + return tp2.rebind(approximateOr(tp1, tp2.parent)) + case CapturingType(parent2, refs2, _) => + return tp2.derivedCapturingType(approximateOr(tp1, parent2), refs2) + case err: ErrorType => + return err case _ => } diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index 769defc3e42c..a64cab59b2aa 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -38,6 +38,8 @@ import scala.util.hashing.{ MurmurHash3 => hashing } import config.Printers.{core, typr, matchTypes} import reporting.{trace, Message} import java.lang.ref.WeakReference +import cc.{CapturingType, CaptureSet, derivedCapturingType, retainedElems, isBoxedCapturing} +import CaptureSet.CompareResult import scala.annotation.internal.sharable import scala.annotation.threadUnsafe @@ -67,7 +69,7 @@ object Types { * | | +--- SkolemType * | +- TypeParamRef * | +- RefinedOrRecType -+-- RefinedType - * | | -+-- RecType + * | | +-- RecType * | +- AppliedType * | +- TypeBounds * | +- ExprType @@ -187,7 +189,7 @@ object Types { * It makes no sense for it to be an alias type because isRef would always * return false in that case. */ - def isRef(sym: Symbol, skipRefined: Boolean = true)(using Context): Boolean = stripped match { + def isRef(sym: Symbol, skipRefined: Boolean = true)(using Context): Boolean = this match { case this1: TypeRef => this1.info match { // see comment in Namer#typeDefSig case TypeAlias(tp) => tp.isRef(sym, skipRefined) @@ -199,6 +201,12 @@ object Types { val this2 = this1.dealias if (this2 ne this1) this2.isRef(sym, skipRefined) else this1.underlying.isRef(sym, skipRefined) + case this1: TypeVar => + this1.instanceOpt.isRef(sym, skipRefined) + case this1: AnnotatedType => + this1 match + case CapturingType(_, _, _) => false + case _ => this1.parent.isRef(sym, skipRefined) case _ => false } @@ -371,6 +379,7 @@ object Types { case tp: AndOrType => tp.tp1.unusableForInference || tp.tp2.unusableForInference case tp: LambdaType => tp.resultType.unusableForInference || tp.paramInfos.exists(_.unusableForInference) case WildcardType(optBounds) => optBounds.unusableForInference + case CapturingType(parent, refs, _) => parent.unusableForInference || refs.elems.exists(_.unusableForInference) case _: ErrorType => true case _ => false @@ -1185,9 +1194,13 @@ object Types { */ def stripAnnots(using Context): Type = this - /** Strip TypeVars and Annotation wrappers */ + /** Strip TypeVars and Annotation and CapturingType wrappers */ def stripped(using Context): Type = this + def strippedDealias(using Context): Type = + val tp1 = stripped.dealias + if tp1 ne this then tp1.strippedDealias else this + def rewrapAnnots(tp: Type)(using Context): Type = tp.stripTypeVar match { case AnnotatedType(tp1, annot) => AnnotatedType(rewrapAnnots(tp1), annot) case _ => this @@ -1379,8 +1392,13 @@ object Types { val tp1 = tp.instanceOpt if (tp1.exists) tp1.dealias1(keep, keepOpaques) else tp case tp: AnnotatedType => - val tp1 = tp.parent.dealias1(keep, keepOpaques) - if keep(tp) then tp.derivedAnnotatedType(tp1, tp.annot) else tp1 + val parent1 = tp.parent.dealias1(keep, keepOpaques) + tp match + case tp @ CapturingType(parent, refs, _) => + tp.derivedCapturingType(parent1, refs) + case _ => + if keep(tp) then tp.derivedAnnotatedType(parent1, tp.annot) + else parent1 case tp: LazyRef => tp.ref.dealias1(keep, keepOpaques) case _ => this @@ -1479,7 +1497,7 @@ object Types { if (tp.tycon.isLambdaSub) NoType else tp.superType.underlyingClassRef(refinementOK) case tp: AnnotatedType => - tp.underlying.underlyingClassRef(refinementOK) + tp.parent.underlyingClassRef(refinementOK) case tp: RefinedType => if (refinementOK) tp.underlying.underlyingClassRef(refinementOK) else NoType case tp: RecType => @@ -1522,6 +1540,8 @@ object Types { case _ => if (isRepeatedParam) this.argTypesHi.head else this } + def captureSet(using Context): CaptureSet = CaptureSet.ofType(this) + // ----- Normalizing typerefs over refined types ---------------------------- /** If this normalizes* to a refinement type that has a refinement for `name` (which might be followed @@ -1805,7 +1825,7 @@ object Types { * @param dropLast The number of trailing parameters that should be dropped * when forming the function type. */ - def toFunctionType(isJava: Boolean, dropLast: Int = 0)(using Context): Type = this match { + def toFunctionType(isJava: Boolean, dropLast: Int = 0, alwaysDependent: Boolean = false)(using Context): Type = this match { case mt: MethodType if !mt.isParamDependent => val formals1 = if (dropLast == 0) mt.paramInfos else mt.paramInfos dropRight dropLast val isContextual = mt.isContextualMethod && !ctx.erasedTypes @@ -1817,7 +1837,7 @@ object Types { val funType = defn.FunctionOf( formals1 mapConserve (_.translateFromRepeated(toArray = isJava)), result1, isContextual, isErased) - if (mt.isResultDependent) RefinedType(funType, nme.apply, mt) + if alwaysDependent || mt.isResultDependent then RefinedType(funType, nme.apply, mt) else funType } @@ -1849,6 +1869,16 @@ object Types { case _ => this } + def capturing(ref: CaptureRef)(using Context): Type = + if captureSet.accountsFor(ref) then this + else CapturingType(this, ref.singletonCaptureSet, this.isBoxedCapturing) + + def capturing(cs: CaptureSet)(using Context): Type = + if cs.isConst && cs.subCaptures(captureSet, frozen = true).isOK then this + else this match + case CapturingType(parent, cs1, boxed) => parent.capturing(cs1 ++ cs) + case _ => CapturingType(this, cs, this.isBoxedCapturing) + /** The set of distinct symbols referred to by this type, after all aliases are expanded */ def coveringSet(using Context): Set[Symbol] = (new CoveringSetAccumulator).apply(Set.empty[Symbol], this) @@ -2029,6 +2059,40 @@ object Types { def isOverloaded(using Context): Boolean = false } + /** A trait for references in CaptureSets. These can be NamedTypes, ThisTypes or ParamRefs */ + trait CaptureRef extends SingletonType: + private var myCaptureSet: CaptureSet = _ + private var myCaptureSetRunId: Int = NoRunId + private var mySingletonCaptureSet: CaptureSet.Const = null + + def canBeTracked(using Context): Boolean + final def isTracked(using Context): Boolean = canBeTracked && !captureSetOfInfo.isAlwaysEmpty + def isRootCapability(using Context): Boolean = false + def normalizedRef(using Context): CaptureRef = this + + def singletonCaptureSet(using Context): CaptureSet.Const = + if mySingletonCaptureSet == null then + mySingletonCaptureSet = CaptureSet(this.normalizedRef) + mySingletonCaptureSet + + def captureSetOfInfo(using Context): CaptureSet = + if ctx.runId == myCaptureSetRunId then myCaptureSet + else if myCaptureSet eq CaptureSet.Pending then CaptureSet.empty + else + myCaptureSet = CaptureSet.Pending + val computed = CaptureSet.ofInfo(this) + if ctx.phase != Phases.checkCapturesPhase || underlying.isProvisional then + myCaptureSet = null + else + myCaptureSet = computed + myCaptureSetRunId = ctx.runId + computed + + override def captureSet(using Context): CaptureSet = + val cs = captureSetOfInfo + if canBeTracked && !cs.isAlwaysEmpty then singletonCaptureSet else cs + end CaptureRef + /** A trait for types that bind other types that refer to them. * Instances are: LambdaType, RecType. */ @@ -2076,7 +2140,7 @@ object Types { // --- NamedTypes ------------------------------------------------------------------ - abstract class NamedType extends CachedProxyType with ValueType { self => + abstract class NamedType extends CachedProxyType, ValueType { self => type ThisType >: this.type <: NamedType type ThisName <: Name @@ -2095,6 +2159,9 @@ object Types { private var mySignature: Signature = _ private var mySignatureRunId: Int = NoRunId + private var myCaptureSet: CaptureSet = _ + private var myCaptureSetRunId: Int = NoRunId + // Invariants: // (1) checkedPeriod != Nowhere => lastDenotation != null // (2) lastDenotation != null => lastSymbol != null @@ -2447,7 +2514,7 @@ object Types { val tparam = symbol val cls = tparam.owner val base = pre.baseType(cls) - base match { + base.stripped match { case AppliedType(_, allArgs) => var tparams = cls.typeParams var args = allArgs @@ -2627,7 +2694,7 @@ object Types { */ abstract case class TermRef(override val prefix: Type, private var myDesignator: Designator) - extends NamedType with SingletonType with ImplicitRef { + extends NamedType, ImplicitRef, CaptureRef { type ThisType = TermRef type ThisName = TermName @@ -2651,6 +2718,25 @@ object Types { def implicitName(using Context): TermName = name def underlyingRef: TermRef = this + + /** A term reference can be tracked if it is a local term ref to a value + * or a method term parameter. References to term parameters of classes + * cannot be tracked individually. + * They are subsumed in the capture sets of the enclosing class. + * TODO: ^^^ What avout call-by-name? + */ + def canBeTracked(using Context) = + ((prefix eq NoPrefix) + || symbol.is(ParamAccessor) && (prefix eq symbol.owner.thisType) + || symbol.hasAnnotation(defn.AbilityAnnot) + || isRootCapability + ) && !symbol.is(Method) + + override def isRootCapability(using Context): Boolean = + name == nme.CAPTURE_ROOT && symbol == defn.captureRoot + + override def normalizedRef(using Context): CaptureRef = + if canBeTracked then symbol.termRef else this } abstract case class TypeRef(override val prefix: Type, @@ -2786,7 +2872,7 @@ object Types { * Note: we do not pass a class symbol directly, because symbols * do not survive runs whereas typerefs do. */ - abstract case class ThisType(tref: TypeRef) extends CachedProxyType with SingletonType { + abstract case class ThisType(tref: TypeRef) extends CachedProxyType, CaptureRef { def cls(using Context): ClassSymbol = tref.stableInRunSymbol match { case cls: ClassSymbol => cls case _ if ctx.mode.is(Mode.Interactive) => defn.AnyClass // was observed to happen in IDE mode @@ -2800,6 +2886,8 @@ object Types { // can happen in IDE if `cls` is stale } + def canBeTracked(using Context) = true + override def computeHash(bs: Binders): Int = doHash(bs, tref) override def eql(that: Type): Boolean = that match { @@ -3626,9 +3714,17 @@ object Types { case tp: AppliedType => tp.fold(status, compute(_, _, theAcc)) case tp: TypeVar if !tp.isInstantiated => combine(status, Provisional) case tp: TermParamRef if tp.binder eq thisLambdaType => TrueDeps - case AnnotatedType(parent, ann) => - if ann.refersToParamOf(thisLambdaType) then TrueDeps - else compute(status, parent, theAcc) + case tp: AnnotatedType => + tp match + case CapturingType(parent, refs, _) => + (compute(status, parent, theAcc) /: refs.elems) { + (s, ref) => ref match + case tp: TermParamRef if tp.binder eq thisLambdaType => combine(s, CaptureDeps) + case _ => s + } + case _ => + if tp.annot.refersToParamOf(thisLambdaType) then TrueDeps + else compute(status, tp.parent, theAcc) case _: ThisType | _: BoundType | NoPrefix => status case _ => (if theAcc != null then theAcc else DepAcc()).foldOver(status, tp) @@ -3667,29 +3763,52 @@ object Types { /** Does result type contain references to parameters of this method type, * which cannot be eliminated by de-aliasing? */ - def isResultDependent(using Context): Boolean = dependencyStatus == TrueDeps + def isResultDependent(using Context): Boolean = + dependencyStatus == TrueDeps || dependencyStatus == CaptureDeps /** Does one of the parameter types contain references to earlier parameters * of this method type which cannot be eliminated by de-aliasing? */ def isParamDependent(using Context): Boolean = paramDependencyStatus == TrueDeps + /** Is there either a true or false type dependency, or does the result + * type capture a parameter? + */ + def isCaptureDependent(using Context) = dependencyStatus == CaptureDeps + def newParamRef(n: Int): TermParamRef = new TermParamRefImpl(this, n) /** The least supertype of `resultType` that does not contain parameter dependencies */ def nonDependentResultApprox(using Context): Type = - if (isResultDependent) { + if isResultDependent then val dropDependencies = new ApproximatingTypeMap { def apply(tp: Type) = tp match { case tp @ TermParamRef(`thisLambdaType`, _) => range(defn.NothingType, atVariance(1)(apply(tp.underlying))) + case CapturingType(parent, refs, boxed) => + val parent1 = this(parent) + val elems1 = refs.elems.filter { + case tp @ TermParamRef(`thisLambdaType`, _) => false + case _ => true + } + if elems1.size == refs.elems.size then + derivedCapturingType(tp, parent1, refs) + else + range( + CapturingType(parent1, CaptureSet(elems1), boxed), + CapturingType(parent1, CaptureSet.universal, boxed)) case AnnotatedType(parent, ann) if ann.refersToParamOf(thisLambdaType) => - mapOver(parent) + val parent1 = mapOver(parent) + if ann.symbol == defn.RetainsAnnot then + range( + AnnotatedType(parent1, CaptureSet.empty.toRegularAnnotation), + AnnotatedType(parent1, CaptureSet.universal.toRegularAnnotation)) + else + parent1 case _ => mapOver(tp) } } dropDependencies(resultType) - } else resultType } @@ -4060,9 +4179,10 @@ object Types { final val Unknown: DependencyStatus = 0 // not yet computed final val NoDeps: DependencyStatus = 1 // no dependent parameters found final val FalseDeps: DependencyStatus = 2 // all dependent parameters are prefixes of non-depended alias types - final val TrueDeps: DependencyStatus = 3 // some truly dependent parameters exist - final val StatusMask: DependencyStatus = 3 // the bits indicating actual dependency status - final val Provisional: DependencyStatus = 4 // set if dependency status can still change due to type variable instantiations + final val CaptureDeps: DependencyStatus = 3 + final val TrueDeps: DependencyStatus = 4 // some truly dependent parameters exist + final val StatusMask: DependencyStatus = 7 // the bits indicating actual dependency status + final val Provisional: DependencyStatus = 8 // set if dependency status can still change due to type variable instantiations } // ----- Type application: LambdaParam, AppliedType --------------------- @@ -4532,8 +4652,9 @@ object Types { /** Only created in `binder.paramRefs`. Use `binder.paramRefs(paramNum)` to * refer to `TermParamRef(binder, paramNum)`. */ - abstract case class TermParamRef(binder: TermLambda, paramNum: Int) extends ParamRef with SingletonType { + abstract case class TermParamRef(binder: TermLambda, paramNum: Int) extends ParamRef, CaptureRef { type BT = TermLambda + def canBeTracked(using Context) = true def kindString: String = "Term" def copyBoundType(bt: BT): Type = bt.paramRefs(paramNum) } @@ -5181,7 +5302,7 @@ object Types { // ----- Annotated and Import types ----------------------------------------------- /** An annotated type tpe @ annot */ - abstract case class AnnotatedType(parent: Type, annot: Annotation) extends CachedProxyType with ValueType { + abstract case class AnnotatedType(parent: Type, annot: Annotation) extends CachedProxyType, ValueType { override def underlying(using Context): Type = parent @@ -5210,16 +5331,16 @@ object Types { // equals comes from case class; no matching override is needed override def computeHash(bs: Binders): Int = - doHash(bs, System.identityHashCode(annot), parent) + doHash(bs, annot.hash, parent) override def hashIsStable: Boolean = parent.hashIsStable override def eql(that: Type): Boolean = that match - case that: AnnotatedType => (parent eq that.parent) && (annot eq that.annot) + case that: AnnotatedType => (parent eq that.parent) && (annot eql that.annot) case _ => false override def iso(that: Any, bs: BinderPairs): Boolean = that match - case that: AnnotatedType => parent.equals(that.parent, bs) && (annot eq that.annot) + case that: AnnotatedType => parent.equals(that.parent, bs) && (annot eql that.annot) case _ => false } @@ -5230,6 +5351,7 @@ object Types { annots.foldLeft(underlying)(apply(_, _)) def apply(parent: Type, annot: Annotation)(using Context): AnnotatedType = unique(CachedAnnotatedType(parent, annot)) + end AnnotatedType // Special type objects and classes ----------------------------------------------------- @@ -5449,7 +5571,7 @@ object Types { /** Common base class of TypeMap and TypeAccumulator */ abstract class VariantTraversal: - protected[core] var variance: Int = 1 + protected[dotc] var variance: Int = 1 inline protected def atVariance[T](v: Int)(op: => T): T = { val saved = variance @@ -5475,6 +5597,24 @@ object Types { } end VariantTraversal + /** A supertrait for some typemaps that are bijections. Used for capture checking + * BiTypeMaps should map capture references to capture references. + */ + trait BiTypeMap extends TypeMap: + thisMap => + def inverse(tp: Type): Type + + def inverseTypeMap(using Context) = new BiTypeMap: + def apply(tp: Type) = thisMap.inverse(tp) + def inverse(tp: Type) = thisMap.apply(tp) + + def forward(ref: CaptureRef): CaptureRef = this(ref) match + case result: CaptureRef if result.canBeTracked => result + + def backward(ref: CaptureRef): CaptureRef = inverse(ref) match + case result: CaptureRef if result.canBeTracked => result + end BiTypeMap + abstract class TypeMap(implicit protected var mapCtx: Context) extends VariantTraversal with (Type => Type) { thisMap => @@ -5502,6 +5642,8 @@ object Types { tp.derivedMatchType(bound, scrutinee, cases) protected def derivedAnnotatedType(tp: AnnotatedType, underlying: Type, annot: Annotation): Type = tp.derivedAnnotatedType(underlying, annot) + protected def derivedCapturingType(tp: Type, parent: Type, refs: CaptureSet): Type = + tp.derivedCapturingType(parent, refs) protected def derivedWildcardType(tp: WildcardType, bounds: Type): Type = tp.derivedWildcardType(bounds) protected def derivedSkolemType(tp: SkolemType, info: Type): Type = @@ -5537,6 +5679,12 @@ object Types { def isRange(tp: Type): Boolean = tp.isInstanceOf[Range] + protected def mapCapturingType(tp: Type, parent: Type, refs: CaptureSet, v: Int): Type = + val saved = variance + variance = v + try derivedCapturingType(tp, this(parent), refs.map(this)) + finally variance = saved + /** Map this function over given type */ def mapOver(tp: Type): Type = { record(s"TypeMap mapOver ${getClass}") @@ -5578,6 +5726,9 @@ object Types { case tp: ExprType => derivedExprType(tp, this(tp.resultType)) + case CapturingType(parent, refs, _) => + mapCapturingType(tp, parent, refs, variance) + case tp @ AnnotatedType(underlying, annot) => val underlying1 = this(underlying) val annot1 = annot.mapWith(this) @@ -5905,6 +6056,13 @@ object Types { if (underlying.isExactlyNothing) underlying else tp.derivedAnnotatedType(underlying, annot) } + override protected def derivedCapturingType(tp: Type, parent: Type, refs: CaptureSet): Type = + parent match // ^^^ handle ranges in capture sets as well + case Range(lo, hi) => + range(derivedCapturingType(tp, lo, refs), derivedCapturingType(tp, hi, refs)) + case _ => + tp.derivedCapturingType(parent, refs) + override protected def derivedWildcardType(tp: WildcardType, bounds: Type): WildcardType = tp.derivedWildcardType(rangeToBounds(bounds)) @@ -5945,6 +6103,12 @@ object Types { tp.derivedLambdaType(tp.paramNames, formals, restpe) } + override def mapCapturingType(tp: Type, parent: Type, refs: CaptureSet, v: Int): Type = + if v == 0 then + range(mapCapturingType(tp, parent, refs, -1), mapCapturingType(tp, parent, refs, 1)) + else + super.mapCapturingType(tp, parent, refs, v) + protected def reapply(tp: Type): Type = apply(tp) } @@ -6042,6 +6206,9 @@ object Types { val x2 = atVariance(0)(this(x1, tp.scrutinee)) foldOver(x2, tp.cases) + case CapturingType(parent, refs, _) => + (this(x, parent) /: refs.elems)(this) + case AnnotatedType(underlying, annot) => this(applyToAnnot(x, annot), underlying) diff --git a/compiler/src/dotty/tools/dotc/core/Variances.scala b/compiler/src/dotty/tools/dotc/core/Variances.scala index 122c7a10e4b7..44dda6b0077e 100644 --- a/compiler/src/dotty/tools/dotc/core/Variances.scala +++ b/compiler/src/dotty/tools/dotc/core/Variances.scala @@ -4,6 +4,7 @@ package core import Types._, Contexts._, Flags._, Symbols._, Annotations._ import TypeApplications.TypeParamInfo import Decorators._ +import cc.CapturingType object Variances { @@ -99,6 +100,8 @@ object Variances { v } varianceInArgs(varianceInType(tycon)(tparam), args, tycon.typeParams) + case CapturingType(tp, _, _) => + varianceInType(tp)(tparam) case AnnotatedType(tp, annot) => varianceInType(tp)(tparam) & varianceInAnnot(annot)(tparam) case AndType(tp1, tp2) => diff --git a/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala b/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala index af186e825591..a217b76944fd 100644 --- a/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala +++ b/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala @@ -821,7 +821,7 @@ class TreeUnpickler(reader: TastyReader, def TypeDef(rhs: Tree) = ta.assignType(untpd.TypeDef(sym.name.asTypeName, rhs), sym) - def ta = ctx.typeAssigner + def ta = ctx.typeAssigner val name = readName() pickling.println(s"reading def of $name at $start") @@ -1263,11 +1263,9 @@ class TreeUnpickler(reader: TastyReader, // types. This came up in #137 of collection strawman. val tycon = readTpt() val args = until(end)(readTpt()) - val ownType = - if (tycon.symbol == defn.andType) AndType(args(0).tpe, args(1).tpe) - else if (tycon.symbol == defn.orType) OrType(args(0).tpe, args(1).tpe, soft = false) - else tycon.tpe.safeAppliedTo(args.tpes) - untpd.AppliedTypeTree(tycon, args).withType(ownType) + val tree = untpd.AppliedTypeTree(tycon, args) + val ownType = ctx.typeAssigner.processAppliedType(tree, tycon.tpe.safeAppliedTo(args.tpes)) + tree.withType(ownType) case ANNOTATEDtpt => Annotated(readTpt(), readTerm()) case LAMBDAtpt => diff --git a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala index 3ffef90057ef..2df71a25766a 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala @@ -894,6 +894,24 @@ object Parsers { } } + def followingIsCaptureSet(): Boolean = + val lookahead = in.LookaheadScanner() + def recur(): Boolean = + (lookahead.isIdent || lookahead.token == THIS) && { + lookahead.nextToken() + if lookahead.token == COMMA then + lookahead.nextToken() + recur() + else + lookahead.token == RBRACE && { + lookahead.nextToken() + canStartInfixTypeTokens.contains(lookahead.token) + || lookahead.token == LBRACKET + } + } + lookahead.nextToken() + recur() + /* --------- OPERAND/OPERATOR STACK --------------------------------------- */ var opStack: List[OpInfo] = Nil @@ -1334,17 +1352,25 @@ object Parsers { case _ => false } + /** CaptureRef ::= ident | `this` + */ + def captureRef(): Tree = + if in.token == THIS then simpleRef() else termIdent() + /** Type ::= FunType * | HkTypeParamClause ‘=>>’ Type * | FunParamClause ‘=>>’ Type * | MatchType * | InfixType + * | CaptureSet Type * FunType ::= (MonoFunType | PolyFunType) * MonoFunType ::= FunTypeArgs (‘=>’ | ‘?=>’) Type * PolyFunType ::= HKTypeParamClause '=>' Type * FunTypeArgs ::= InfixType * | `(' [ [ ‘[using]’ ‘['erased'] FunArgType {`,' FunArgType } ] `)' * | '(' [ ‘[using]’ ‘['erased'] TypedFunParam {',' TypedFunParam } ')' + * CaptureSet ::= `{` CaptureRef {`,` CaptureRef} `}` + * CaptureRef ::= Ident */ def typ(): Tree = { val start = in.offset @@ -1450,6 +1476,10 @@ object Parsers { } else { accept(TLARROW); typ() } } + else if in.token == LBRACE && followingIsCaptureSet() then + val refs = inBraces { commaSeparated(captureRef) } + val t = typ() + CapturingTypeTree(refs, t) else if (in.token == INDENT) enclosed(INDENT, typ()) else infixType() @@ -1518,7 +1548,7 @@ object Parsers { def infixType(): Tree = infixTypeRest(refinedType()) def infixTypeRest(t: Tree): Tree = - infixOps(t, canStartTypeTokens, refinedTypeFn, Location.ElseWhere, + infixOps(t, canStartInfixTypeTokens, refinedTypeFn, Location.ElseWhere, isType = true, isOperator = !followingIsVararg()) @@ -3168,7 +3198,7 @@ object Parsers { ImportSelector( atSpan(in.skipToken()) { Ident(nme.EMPTY) }, bound = - if canStartTypeTokens.contains(in.token) then rejectWildcardType(infixType()) + if canStartInfixTypeTokens.contains(in.token) then rejectWildcardType(infixType()) else EmptyTree) /** id [‘as’ (id | ‘_’) */ diff --git a/compiler/src/dotty/tools/dotc/parsing/Tokens.scala b/compiler/src/dotty/tools/dotc/parsing/Tokens.scala index 55f428cef5a4..b4ef65a6035b 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Tokens.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Tokens.scala @@ -230,8 +230,8 @@ object Tokens extends TokensCommon { final val canStartExprTokens2: TokenSet = canStartExprTokens3 | BitSet(DO) - final val canStartTypeTokens: TokenSet = literalTokens | identifierTokens | BitSet( - THIS, SUPER, USCORE, LPAREN, AT) + final val canStartInfixTypeTokens: TokenSet = literalTokens | identifierTokens | BitSet( + THIS, SUPER, USCORE, LPAREN, LBRACE, AT) final val templateIntroTokens: TokenSet = BitSet(CLASS, TRAIT, OBJECT, ENUM, CASECLASS, CASEOBJECT) diff --git a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala index d2efbeff2901..e2513ec7b9df 100644 --- a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala @@ -14,13 +14,16 @@ import Variances.varianceSign import util.SourcePosition import scala.util.control.NonFatal import scala.annotation.switch +import config.Config +import cc.{CapturingType, CaptureSet} class PlainPrinter(_ctx: Context) extends Printer { + /** The context of all public methods in Printer and subclasses. * Overridden in RefinedPrinter. */ - protected def curCtx: Context = _ctx.addMode(Mode.Printing) - protected given [DummyToEnforceDef]: Context = curCtx + def printerContext: Context = _ctx.addMode(Mode.Printing) + protected given [DummyToEnforceDef]: Context = printerContext protected def printDebug = ctx.settings.YprintDebug.value @@ -194,6 +197,22 @@ class PlainPrinter(_ctx: Context) extends Printer { keywordStr(" match ") ~ "{" ~ casesText ~ "}" ~ (" <: " ~ toText(bound) provided !bound.isAny) }.close + case CapturingType(parent, refs, boxed) => + def box = Str("box ") provided boxed + if printDebug && !refs.isConst then + changePrec(GlobalPrec)(box ~ s"$refs " ~ toText(parent)) + else if ctx.settings.YccDebug.value then + changePrec(GlobalPrec)(box ~ refs.toText(this) ~ " " ~ toText(parent)) + else if !refs.isConst && refs.elems.isEmpty then + changePrec(GlobalPrec)("?" ~ " " ~ toText(parent)) + else if Config.printCaptureSetsAsPrefix then + changePrec(GlobalPrec)( + box ~ "{" + ~ Text(refs.elems.toList.map(toTextCaptureRef), ", ") + ~ "} " + ~ toText(parent)) + else + changePrec(InfixPrec)(toText(parent) ~ " retains " ~ box ~ toText(refs.toRetainsTypeArg)) case tp: PreviousErrorType if ctx.settings.XprintTypes.value => "" // do not print previously reported error message because they may try to print this error type again recuresevely case tp: ErrorType => @@ -325,7 +344,7 @@ class PlainPrinter(_ctx: Context) extends Printer { case tp @ ConstantType(value) => toText(value) case pref: TermParamRef => - nameString(pref.binder.paramNames(pref.paramNum)) + nameString(pref.binder.paramNames(pref.paramNum)) ~ lambdaHash(pref.binder) case tp: RecThis => val idx = openRecs.reverse.indexOf(tp.binder) if (idx >= 0) selfRecName(idx + 1) @@ -346,6 +365,11 @@ class PlainPrinter(_ctx: Context) extends Printer { } } + def toTextCaptureRef(tp: Type): Text = + homogenize(tp) match + case tp: SingletonType => toTextRef(tp) + case _ => toText(tp) + protected def isOmittablePrefix(sym: Symbol): Boolean = defn.unqualifiedOwnerTypes.exists(_.symbol == sym) || isEmptyPrefix(sym) diff --git a/compiler/src/dotty/tools/dotc/printing/Printer.scala b/compiler/src/dotty/tools/dotc/printing/Printer.scala index 550bdb94af4f..b883b6be805b 100644 --- a/compiler/src/dotty/tools/dotc/printing/Printer.scala +++ b/compiler/src/dotty/tools/dotc/printing/Printer.scala @@ -6,7 +6,7 @@ import core._ import Texts._, ast.Trees._ import Types.{Type, SingletonType, LambdaParam}, Symbols.Symbol, Scopes.Scope, Constants.Constant, - Names.Name, Denotations._, Annotations.Annotation + Names.Name, Denotations._, Annotations.Annotation, Contexts.Context import typer.Implicits.SearchResult import util.SourcePosition import typer.ImportInfo @@ -104,6 +104,9 @@ abstract class Printer { /** Textual representation of a prefix of some reference, ending in `.` or `#` */ def toTextPrefix(tp: Type): Text + /** Textual representation of a reference in a capture set */ + def toTextCaptureRef(tp: Type): Text + /** Textual representation of symbol's declaration */ def dclText(sym: Symbol): Text @@ -182,6 +185,9 @@ abstract class Printer { /** A plain printer without any embellishments */ def plain: Printer + + /** The context in which this printer operates */ + def printerContext: Context } object Printer { diff --git a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala index cf5942a178f0..2fb1715d4cfc 100644 --- a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala @@ -34,11 +34,11 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { /** A stack of enclosing DefDef, TypeDef, or ClassDef, or ModuleDefs nodes */ private var enclosingDef: untpd.Tree = untpd.EmptyTree - private var myCtx: Context = super.curCtx + private var myCtx: Context = super.printerContext private var printPos = ctx.settings.YprintPos.value private val printLines = ctx.settings.printLines.value - override protected def curCtx: Context = myCtx + override def printerContext: Context = myCtx def withEnclosingDef(enclDef: Tree[? >: Untyped])(op: => Text): Text = { val savedCtx = myCtx @@ -164,10 +164,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { changePrec(GlobalPrec) { "(" ~ keywordText("erased ").provided(info.isErasedMethod) - ~ ( if info.isParamDependent || info.isResultDependent - then paramsText(info) - else argsText(info.paramInfos) - ) + ~ paramsText(info) ~ ") " ~ arrow(info.isImplicitMethod) ~ " " @@ -245,9 +242,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { if !printDebug && appliedText(tp.asInstanceOf[HKLambda].resType).isEmpty => // don't eta contract if the application would be printed specially toText(tycon) - case tp: RefinedType - if (defn.isFunctionType(tp) || (tp.parent.typeSymbol eq defn.PolyFunctionClass)) - && !printDebug => + case tp: RefinedType if defn.isFunctionOrPolyType(tp) && !printDebug => toTextMethodAsFunction(tp.refinedInfo) case tp: TypeRef => if (tp.symbol.isAnonymousClass && !showUniqueIds) @@ -703,6 +698,8 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { val (prefix, postfix) = if isTermHole then ("{{{ ", " }}}") else ("[[[ ", " ]]]") val argsText = toTextGlobal(args, ", ") prefix ~~ idx.toString ~~ "|" ~~ argsText ~~ postfix + case CapturingTypeTree(refs, parent) => + changePrec(GlobalPrec)("{" ~ Text(refs.map(toText), ", ") ~ "} " ~ toText(parent)) case _ => tree.fallbackToText(this) } @@ -789,9 +786,9 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { if mdef.hasType then Modifiers(mdef.symbol) else mdef.rawMods private def Modifiers(sym: Symbol): Modifiers = untpd.Modifiers( - sym.flags & (if (sym.isType) ModifierFlags | VarianceFlags else ModifierFlags), + sym.flagsUNSAFE & (if (sym.isType) ModifierFlags | VarianceFlags else ModifierFlags), if (sym.privateWithin.exists) sym.privateWithin.asType.name else tpnme.EMPTY, - sym.annotations.filterNot(ann => dropAnnotForModText(ann.symbol)).map(_.tree)) + sym.annotationsUNSAFE.filterNot(ann => dropAnnotForModText(ann.symbol)).map(_.tree)) protected def dropAnnotForModText(sym: Symbol): Boolean = sym == defn.BodyAnnot @@ -996,13 +993,13 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { else if (suppressKw) PrintableFlags(isType) &~ Private else PrintableFlags(isType) if (homogenizedView && mods.flags.isTypeFlags) flagMask &~= GivenOrImplicit // drop implicit/given from classes - val rawFlags = if (sym.exists) sym.flags else mods.flags + val rawFlags = if (sym.exists) sym.flagsUNSAFE else mods.flags if (rawFlags.is(Param)) flagMask = flagMask &~ Given &~ Erased val flags = rawFlags & flagMask var flagsText = toTextFlags(sym, flags) val annotTexts = if sym.exists then - sym.annotations.filterNot(ann => dropAnnotForModText(ann.symbol)).map(toText) + sym.annotationsUNSAFE.filterNot(ann => dropAnnotForModText(ann.symbol)).map(toText) else mods.annotations.filterNot(tree => dropAnnotForModText(tree.symbol)).map(annotText(NoSymbol, _)) Text(annotTexts, " ") ~~ flagsText ~~ (Str(kw) provided !suppressKw) diff --git a/compiler/src/dotty/tools/dotc/reporting/messages.scala b/compiler/src/dotty/tools/dotc/reporting/messages.scala index 4a1efab782a1..96676f04ce99 100644 --- a/compiler/src/dotty/tools/dotc/reporting/messages.scala +++ b/compiler/src/dotty/tools/dotc/reporting/messages.scala @@ -287,7 +287,6 @@ import transform.SymUtils._ val treeStr = inTree.map(x => s"\nTree: ${x.show}").getOrElse("") treeStr + "\n" + super.explain - end TypeMismatch class NotAMember(site: Type, val name: Name, selected: String, addendum: => String = "")(using Context) diff --git a/compiler/src/dotty/tools/dotc/sbt/ExtractAPI.scala b/compiler/src/dotty/tools/dotc/sbt/ExtractAPI.scala index 3e0e148f7101..dc415d98d87e 100644 --- a/compiler/src/dotty/tools/dotc/sbt/ExtractAPI.scala +++ b/compiler/src/dotty/tools/dotc/sbt/ExtractAPI.scala @@ -189,6 +189,7 @@ private class ExtractAPICollector(using Context) extends ThunkHolder { private val byNameMarker = marker("ByName") private val matchMarker = marker("Match") private val superMarker = marker("Super") + private val retainsMarker = marker("Retains") /** Extract the API representation of a source file */ def apiSource(tree: Tree): Seq[api.ClassLike] = { diff --git a/compiler/src/dotty/tools/dotc/transform/EmptyPhase.scala b/compiler/src/dotty/tools/dotc/transform/EmptyPhase.scala new file mode 100644 index 000000000000..9a287b2dd1d9 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/transform/EmptyPhase.scala @@ -0,0 +1,19 @@ +package dotty.tools.dotc +package transform + +import core.* +import Contexts.Context +import Phases.Phase + +/** A phase that can be inserted directly after a phase that cannot + * be checked, to enable a -Ycheck as soon as possible afterwards + */ +class EmptyPhase extends Phase: + + def phaseName: String = "dummy" + + override def isEnabled(using Context) = prev.isEnabled + + override def run(using Context) = () + +end EmptyPhase \ No newline at end of file diff --git a/compiler/src/dotty/tools/dotc/transform/Recheck.scala b/compiler/src/dotty/tools/dotc/transform/Recheck.scala index 76f89cb65757..a61b736a9cc1 100644 --- a/compiler/src/dotty/tools/dotc/transform/Recheck.scala +++ b/compiler/src/dotty/tools/dotc/transform/Recheck.scala @@ -14,15 +14,22 @@ import typer.ErrorReporting.err import typer.ProtoTypes.* import typer.TypeAssigner.seqLitType import typer.ConstFold +import NamerOps.methodType import config.Printers.recheckr import util.Property import StdNames.nme import reporting.trace +object Recheck: + + /** Attachment key for rechecked types of TypeTrees */ + private val RecheckedType = Property.Key[Type] + abstract class Recheck extends Phase, IdentityDenotTransformer: thisPhase => import ast.tpd.* + import Recheck.* def preRecheckPhase = this.prev.asInstanceOf[PreRecheck] @@ -36,12 +43,17 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: override def widenSkolems = true def run(using Context): Unit = - newRechecker().checkUnit(ctx.compilationUnit) + val rechecker = newRechecker() + rechecker.transformTypes.traverse(ctx.compilationUnit.tpdTree) + rechecker.checkUnit(ctx.compilationUnit) def newRechecker()(using Context): Rechecker class Rechecker(ictx: Context): - val ta = ictx.typeAssigner + private val ta = ictx.typeAssigner + private val keepTypes = inContext(ictx) { + ictx.settings.Xprint.value.containsPhase(thisPhase) + } extension (sym: Symbol) def updateInfo(newInfo: Type)(using Context): Unit = if sym.info ne newInfo then @@ -53,23 +65,102 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: else sym.flags ).installAfter(preRecheckPhase) - /** Hook to be overridden */ - protected def reinfer(tp: Type)(using Context): Type = tp - - def reinferResult(info: Type)(using Context): Type = info match - case info: MethodOrPoly => - info.derivedLambdaType(resType = reinferResult(info.resultType)) - case _ => - reinfer(info) + extension (tpe: Type) def rememberFor(tree: Tree)(using Context): Unit = + if (tpe ne tree.tpe) && !tree.hasAttachment(RecheckedType) then + tree.putAttachment(RecheckedType, tpe) + + def knownType(tree: Tree) = + tree.attachmentOrElse(RecheckedType, tree.tpe) + + def isUpdated(sym: Symbol)(using Context) = + val symd = sym.denot + symd.validFor.firstPhaseId == thisPhase.id && (sym.originDenotation ne symd) + + def transformType(tp: Type, inferred: Boolean)(using Context): Type = tp + + object transformTypes extends TreeTraverser: + + // Substitute parameter symbols in `from` to paramRefs in corresponding + // method or poly types `to`. We use a single BiTypeMap to do everything. + class SubstParams(from: List[List[Symbol]], to: List[LambdaType])(using Context) + extends DeepTypeMap, BiTypeMap: + + def apply(t: Type): Type = t match + case t: NamedType => + val sym = t.symbol + def outer(froms: List[List[Symbol]], tos: List[LambdaType]): Type = + def inner(from: List[Symbol], to: List[ParamRef]): Type = + if from.isEmpty then outer(froms.tail, tos.tail) + else if sym eq from.head then to.head + else inner(from.tail, to.tail) + if tos.isEmpty then t + else inner(froms.head, tos.head.paramRefs) + outer(from, to) + case _ => + mapOver(t) + + def inverse(t: Type): Type = t match + case t: ParamRef => + def recur(from: List[LambdaType], to: List[List[Symbol]]): Type = + if from.isEmpty then t + else if t.binder eq from.head then to.head(t.paramNum).namedType + else recur(from.tail, to.tail) + recur(to, from) + case _ => + mapOver(t) + end SubstParams + + def traverse(tree: Tree)(using Context) = + traverseChildren(tree) + tree match - def enterDef(stat: Tree)(using Context): Unit = - val sym = stat.symbol - stat match - case stat: ValOrDefDef if stat.tpt.isInstanceOf[InferredTypeTree] => - sym.updateInfo(reinferResult(sym.info)) - case stat: Bind => - sym.updateInfo(reinferResult(sym.info)) - case _ => + case tree: TypeTree => + transformType(tree.tpe, tree.isInstanceOf[InferredTypeTree]).rememberFor(tree) + case tree: ValOrDefDef => + val sym = tree.symbol + + // replace an existing symbol info with inferred types + def integrateRT( + info: Type, // symbol info to replace + psymss: List[List[Symbol]], // the local (type and trem) parameter symbols corresponding to `info` + prevPsymss: List[List[Symbol]], // the local parameter symbols seen previously in reverse order + prevLambdas: List[LambdaType] // the outer method and polytypes generated previously in reverse order + ): Type = + info match + case mt: MethodOrPoly => + val psyms = psymss.head + mt.companion(mt.paramNames)( + mt1 => + if !psyms.exists(isUpdated) && !mt.isParamDependent && prevLambdas.isEmpty then + mt.paramInfos + else + val subst = SubstParams(psyms :: prevPsymss, mt1 :: prevLambdas) + psyms.map(psym => subst(psym.info).asInstanceOf[mt.PInfo]), + mt1 => + integrateRT(mt.resType, psymss.tail, psyms :: prevPsymss, mt1 :: prevLambdas) + ) + case info: ExprType => + info.derivedExprType(resType = + integrateRT(info.resType, psymss, prevPsymss, prevLambdas)) + case _ => + val restp = knownType(tree.tpt) + if prevLambdas.isEmpty then restp + else SubstParams(prevPsymss, prevLambdas)(restp) + + if tree.tpt.hasAttachment(RecheckedType) && !sym.isConstructor then + val newInfo = integrateRT(sym.info, sym.paramSymss, Nil, Nil) + .showing(i"update info $sym: ${sym.info} --> $result", recheckr) + if newInfo ne sym.info then + val completer = new LazyType: + def complete(denot: SymDenotation)(using Context) = + denot.info = newInfo + recheckDef(tree, sym) + sym.updateInfo(completer) + case tree: Bind => + val sym = tree.symbol + sym.updateInfo(transformType(sym.info, inferred = true)) + case _ => + end transformTypes def constFold(tree: Tree, tp: Type)(using Context): Type = val tree1 = tree.withType(tp) @@ -90,10 +181,10 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: excluded = if tree.symbol.is(Private) then EmptyFlags else Private ).suchThat(tree.symbol ==) constFold(tree, qualType.select(name, mbr)) + //.showing(i"recheck select $qualType . $name : ${mbr.symbol.info} = $result") def recheckBind(tree: Bind, pt: Type)(using Context): Type = tree match case Bind(name, body) => - enterDef(tree) recheck(body, pt) val sym = tree.symbol if sym.isType then sym.typeRef else sym.info @@ -104,16 +195,13 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: val exprType = recheck(expr, defn.UnitType) bindType - def recheckValDef(tree: ValDef, sym: Symbol)(using Context): Type = - if !tree.rhs.isEmpty then recheck(tree.rhs, tree.symbol.info) - sym.termRef + def recheckValDef(tree: ValDef, sym: Symbol)(using Context): Unit = + if !tree.rhs.isEmpty then recheck(tree.rhs, sym.info) - def recheckDefDef(tree: DefDef, sym: Symbol)(using Context): Type = - tree.paramss.foreach(_.foreach(enterDef)) - val rhsCtx = linkConstructorParams(sym) + def recheckDefDef(tree: DefDef, sym: Symbol)(using Context): Unit = + val rhsCtx = linkConstructorParams(sym).withOwner(sym) if !tree.rhs.isEmpty && !sym.isInlineMethod && !sym.isEffectivelyErased then - recheck(tree.rhs, tree.symbol.localReturnType)(using rhsCtx) - sym.termRef + inContext(rhsCtx) { recheck(tree.rhs, recheck(tree.tpt)) } def recheckTypeDef(tree: TypeDef, sym: Symbol)(using Context): Type = recheck(tree.rhs) @@ -134,6 +222,11 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: case _ => mapOver(t) formals.mapConserve(tm) + /** Hook for method type instantiation + */ + protected def instantiate(mt: MethodType, argTypes: List[Type], sym: Symbol)(using Context): Type = + mt.instantiate(argTypes) + def recheckApply(tree: Apply, pt: Type)(using Context): Type = recheck(tree.fun).widen match case fntpe: MethodType => @@ -153,7 +246,7 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: assert(formals.isEmpty) Nil val argTypes = recheckArgs(tree.args, formals, fntpe.paramRefs) - constFold(tree, fntpe.instantiate(argTypes)) + constFold(tree, instantiate(fntpe, argTypes, tree.fun.symbol)) def recheckTypeApply(tree: TypeApply, pt: Type)(using Context): Type = recheck(tree.fun).widen match @@ -174,7 +267,10 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: def recheckBlock(stats: List[Tree], expr: Tree, pt: Type)(using Context): Type = recheckStats(stats) - val exprType = recheck(expr, pt.dropIfProto) + val exprType = recheck(expr) + // The expected type `pt` is not propagated. Doing so would allow variables in the + // expected type to contain references to local symbols of the block, so the + // local symbols could escape that way. TypeOps.avoid(exprType, localSyms(stats).filterConserve(_.isTerm)) def recheckBlock(tree: Block, pt: Type)(using Context): Type = @@ -195,10 +291,10 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: def recheckMatch(tree: Match, pt: Type)(using Context): Type = val selectorType = recheck(tree.selector) - val casesTypes = tree.cases.map(recheck(_, selectorType.widen, pt)) + val casesTypes = tree.cases.map(recheckCase(_, selectorType.widen, pt)) TypeComparer.lub(casesTypes) - def recheck(tree: CaseDef, selType: Type, pt: Type)(using Context): Type = + def recheckCase(tree: CaseDef, selType: Type, pt: Type)(using Context): Type = recheck(tree.pat, selType) recheck(tree.guard, defn.BooleanType) recheck(tree.body, pt) @@ -214,7 +310,7 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: def recheckTry(tree: Try, pt: Type)(using Context): Type = val bodyType = recheck(tree.expr, pt) - val casesTypes = tree.cases.map(recheck(_, defn.ThrowableType, pt)) + val casesTypes = tree.cases.map(recheckCase(_, defn.ThrowableType, pt)) val finalizerType = recheck(tree.finalizer, defn.UnitType) TypeComparer.lub(bodyType :: casesTypes) @@ -227,9 +323,8 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: val elemTypes = tree.elems.map(recheck(_, elemProto)) seqLitType(tree, TypeComparer.lub(declaredElemType :: elemTypes)) - def recheckTypeTree(tree: TypeTree)(using Context): Type = tree match - case tree: InferredTypeTree => reinfer(tree.tpe) - case _ => tree.tpe + def recheckTypeTree(tree: TypeTree)(using Context): Type = + knownType(tree) def recheckAnnotated(tree: Annotated)(using Context): Type = tree.tpe match @@ -246,14 +341,20 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: NoType def recheckStats(stats: List[Tree])(using Context): Unit = - stats.foreach(enterDef) stats.foreach(recheck(_)) + def recheckDef(tree: ValOrDefDef, sym: Symbol)(using Context): Unit = + inContext(ctx.localContext(tree, sym)) { + tree match + case tree: ValDef => recheckValDef(tree, sym) + case tree: DefDef => recheckDefDef(tree, sym) + } + /** Recheck tree without adapting it, returning its new type. * @param tree the original tree * @param pt the expected result type */ - def recheck(tree: Tree, pt: Type = WildcardType)(using Context): Type = trace(i"rechecking $tree with pt = $pt", recheckr, show = true) { + def recheckStart(tree: Tree, pt: Type = WildcardType)(using Context): Type = def recheckNamed(tree: NameTree, pt: Type)(using Context): Type = val sym = tree.symbol @@ -261,11 +362,12 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: case tree: Ident => recheckIdent(tree) case tree: Select => recheckSelect(tree) case tree: Bind => recheckBind(tree, pt) - case tree: ValDef => + case tree: ValOrDefDef => if tree.isEmpty then NoType - else recheckValDef(tree, sym)(using ctx.localContext(tree, sym)) - case tree: DefDef => - recheckDefDef(tree, sym)(using ctx.localContext(tree, sym)) + else + if isUpdated(sym) then sym.ensureCompleted() + else recheckDef(tree, sym) + sym.termRef case tree: TypeDef => tree.rhs match case impl: Template => @@ -295,35 +397,61 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: case tree: PackageDef => recheckPackageDef(tree) case tree: Thicket => defn.NothingType - try - val result = tree match - case tree: NameTree => recheckNamed(tree, pt) - case tree => recheckUnnamed(tree, pt) - checkConforms(result, pt, tree) - result - catch case ex: Exception => - println(i"error while rechecking $tree") - throw ex - } - end recheck + tree match + case tree: NameTree => recheckNamed(tree, pt) + case tree => recheckUnnamed(tree, pt) + end recheckStart + + def recheckFinish(tpe: Type, tree: Tree, pt: Type)(using Context): Type = + checkConforms(tpe, pt, tree) + if keepTypes then tpe.rememberFor(tree) + tpe + + def recheck(tree: Tree, pt: Type = WildcardType)(using Context): Type = + trace(i"rechecking $tree with pt = $pt", recheckr, show = true) { + try recheckFinish(recheckStart(tree, pt), tree, pt) + catch case ex: Exception => + println(i"error while rechecking $tree") + throw ex + } + + private val debugSuccesses = false def checkConforms(tpe: Type, pt: Type, tree: Tree)(using Context): Unit = tree match - case _: DefTree | EmptyTree | _: TypeTree => + case _: DefTree | EmptyTree | _: TypeTree | _: Closure => + // Don't report closure nodes, since their span is a point; wait instead + // for enclosing block to preduce an error case _ => val actual = tpe.widenExpr val expected = pt.widenExpr + //println(i"check conforms $actual <:< $expected") val isCompatible = actual <:< expected || expected.isRepeatedParam && actual <:< expected.translateFromRepeated(toArray = tree.tpe.isRef(defn.ArrayClass)) if !isCompatible then - println(i"err at ${ctx.phase}") - err.typeMismatch(tree.withType(tpe), pt) + err.typeMismatch(tree.withType(tpe), expected) + else if debugSuccesses then + tree match + case _: Ident => + println(i"SUCCESS $tree:\n${TypeComparer.explained(_.isSubType(actual, expected))}") + case _ => def checkUnit(unit: CompilationUnit)(using Context): Unit = recheck(unit.tpdTree) end Rechecker + + override def show(tree: untpd.Tree)(using Context): String = + val addRecheckedTypes = new TreeMap: + override def transform(tree: Tree)(using Context): Tree = + val tree1 = super.transform(tree) + tree.getAttachment(RecheckedType) match + case Some(tpe) => tree1.withType(tpe) + case None => tree1 + atPhase(thisPhase) { + super.show(addRecheckedTypes.transform(tree.asInstanceOf[tpd.Tree])) + } end Recheck class TestRecheck extends Recheck: diff --git a/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala b/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala index 29fd1adb6688..044ea11eb27e 100644 --- a/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala +++ b/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala @@ -375,14 +375,14 @@ class TreeChecker extends Phase with SymTransformer { val tpe = tree.typeOpt // Polymorphic apply methods stay structural until Erasure - val isPolyFunctionApply = (tree.name eq nme.apply) && (tree.qualifier.typeOpt <:< defn.PolyFunctionType) + val isPolyFunctionApply = (tree.name eq nme.apply) && tree.qualifier.typeOpt.derivesFrom(defn.PolyFunctionClass) // Outer selects are pickled specially so don't require a symbol val isOuterSelect = tree.name.is(OuterSelectName) val isPrimitiveArrayOp = ctx.erasedTypes && nme.isPrimitiveName(tree.name) if !(tree.isType || isPolyFunctionApply || isOuterSelect || isPrimitiveArrayOp) then val denot = tree.denot assert(denot.exists, i"Selection $tree with type $tpe does not have a denotation") - assert(denot.symbol.exists, i"Denotation $denot of selection $tree with type $tpe does not have a symbol") + assert(denot.symbol.exists, i"Denotation $denot of selection $tree with type $tpe does not have a symbol, qualifier type = ${tree.qualifier.typeOpt}") val sym = tree.symbol val symIsFixed = tpe match { diff --git a/compiler/src/dotty/tools/dotc/transform/TryCatchPatterns.scala b/compiler/src/dotty/tools/dotc/transform/TryCatchPatterns.scala index 6be58352e6dc..26bea001d1eb 100644 --- a/compiler/src/dotty/tools/dotc/transform/TryCatchPatterns.scala +++ b/compiler/src/dotty/tools/dotc/transform/TryCatchPatterns.scala @@ -70,7 +70,7 @@ class TryCatchPatterns extends MiniPhase { case _ => isDefaultCase(cdef) } - private def isSimpleThrowable(tp: Type)(using Context): Boolean = tp.stripAnnots match { + private def isSimpleThrowable(tp: Type)(using Context): Boolean = tp.stripped match { case tp @ TypeRef(pre, _) => (pre == NoPrefix || pre.typeSymbol.isStatic) && // Does not require outer class check !tp.symbol.is(Flags.Trait) && // Traits not supported by JVM diff --git a/compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala b/compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala index 8ffe2198c4d9..7c5d34126bd9 100644 --- a/compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala +++ b/compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala @@ -148,7 +148,7 @@ object TypeTestsCasts { } case AndType(tp1, tp2) => recur(X, tp1) && recur(X, tp2) case OrType(tp1, tp2) => recur(X, tp1) && recur(X, tp2) - case AnnotatedType(t, _) => recur(X, t) + case tp: AnnotatedType => recur(X, tp.parent) case _: RefinedType => false case _ => true }) @@ -217,7 +217,7 @@ object TypeTestsCasts { * can be true in some cases. Issues a warning or an error otherwise. */ def checkSensical(foundClasses: List[Symbol])(using Context): Boolean = - def exprType = i"type ${expr.tpe.widen.stripAnnots}" + def exprType = i"type ${expr.tpe.widen.stripped}" def check(foundCls: Symbol): Boolean = if (!isCheckable(foundCls)) true else if (!foundCls.derivesFrom(testCls)) { diff --git a/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala new file mode 100644 index 000000000000..1415016fea26 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala @@ -0,0 +1,468 @@ +package dotty.tools +package dotc +package cc + +import core._ +import Phases.*, DenotTransformers.*, SymDenotations.* +import Contexts.*, Names.*, Flags.*, Symbols.*, Decorators.* +import Types._ +import Symbols._ +import StdNames._ +import Decorators._ +import config.Printers.{capt, recheckr} +import ast.{tpd, untpd, Trees} +import NameKinds.{DocArtifactName, OuterSelectName, DefaultGetterName} +import Trees._ +import scala.util.control.NonFatal +import typer.ErrorReporting._ +import typer.RefChecks +import util.Spans.Span +import util.{SimpleIdentitySet, EqHashMap, SrcPos} +import util.Chars.* +import transform.* +import transform.SymUtils.* +import scala.collection.mutable +import reporting._ +import dotty.tools.backend.jvm.DottyBackendInterface.symExtensions +import CaptureSet.{CompareResult, withCaptureSetsExplained} + +object CheckCaptures: + import ast.tpd.* + + case class Env(owner: Symbol, captured: CaptureSet, isBoxed: Boolean, outer: Env): + def isOpen = !captured.isAlwaysEmpty && !isBoxed + + final class SubstParamsMap(from: BindingType, to: List[Type])(using Context) + extends ApproximatingTypeMap: + def apply(tp: Type): Type = tp match + case tp: ParamRef => + if tp.binder == from then to(tp.paramNum) else tp + case tp: NamedType => + if tp.prefix `eq` NoPrefix then tp + else tp.derivedSelect(apply(tp.prefix)) + case _: ThisType => + tp + case _ => + mapOver(tp) + + /** Check that a @retains annotation only mentions references that can be tracked + * This check is performed at Typer. + */ + def checkWellformed(ann: Tree)(using Context): Unit = + for elem <- retainedElems(ann) do + elem.tpe match + case ref: CaptureRef => + if !ref.canBeTracked then + report.error(em"$elem cannot be tracked since it is not a parameter or a local variable", elem.srcPos) + case tpe => + report.error(em"$tpe is not a legal type for a capture set", elem.srcPos) + + /** If `tp` is a capturing type, check that all references it mentions have non-empty + * capture sets. + * This check is performed after capture sets are computed in phase cc. + */ + def checkWellformedPost(tp: Type, pos: SrcPos)(using Context): Unit = tp match + case CapturingType(parent, refs, _) => + for ref <- refs.elems do + if ref.captureSetOfInfo.elems.isEmpty then + report.error(em"$ref cannot be tracked since its capture set is empty", pos) + else if parent.captureSet.accountsFor(ref) then + report.warning(em"redundant capture: $parent already accounts for $ref", pos) + case _ => + + def checkWellformedPost(ann: Tree)(using Context): Unit = + /** The lists `elems(i) :: prev.reerse :: elems(0),...,elems(i-1),elems(i+1),elems(n)` + * where `n == elems.length-1`, i <- 0..n`. + */ + def choices(prev: List[Tree], elems: List[Tree]): List[List[Tree]] = elems match + case Nil => Nil + case elem :: elems => + List(elem :: (prev reverse_::: elems)) ++ choices(elem :: prev, elems) + for case first :: others <- choices(Nil, retainedElems(ann)) do + val firstRef = first.toCaptureRef + val remaining = CaptureSet(others.map(_.toCaptureRef)*) + if remaining.accountsFor(firstRef) then + report.warning(em"redundant capture: $remaining already accounts for $firstRef", ann.srcPos) + + private inline val disallowGlobal = true + +class CheckCaptures extends Recheck: + thisPhase => + + import ast.tpd.* + import CheckCaptures.* + + def phaseName: String = "cc" + override def isEnabled(using Context) = ctx.settings.Ycc.value + + def newRechecker()(using Context) = CaptureChecker(ctx) + + override def run(using Context): Unit = + checkOverrides.traverse(ctx.compilationUnit.tpdTree) + super.run + + def checkOverrides = new TreeTraverser: + def traverse(t: Tree)(using Context) = + t match + case t: Template => + // ^^^ TODO: Can we avoid doing overrides checks twice? + // We need to do them here since only at this phase CaptureTypes are relevant + // But maybe we can then elide the check during the RefChecks phase if -Ycc is set? + RefChecks.checkAllOverrides(ctx.owner.asClass) + case _ => + traverseChildren(t) + + class CaptureChecker(ictx: Context) extends Rechecker(ictx): + import ast.tpd.* + + override def transformType(tp: Type, inferred: Boolean)(using Context): Type = + + def addInnerVars(tp: Type): Type = tp match + case tp @ AppliedType(tycon, args) => + tp.derivedAppliedType(tycon, args.map(addVars(_, boxed = true))) + case tp @ RefinedType(core, rname, rinfo) => + val rinfo1 = addVars(rinfo) + if defn.isFunctionType(tp) then + rinfo1.toFunctionType(isJava = false, alwaysDependent = true) + else + tp.derivedRefinedType(addInnerVars(core), rname, rinfo1) + case tp: MethodType => + tp.derivedLambdaType( + paramInfos = tp.paramInfos.mapConserve(addVars(_)), + resType = addVars(tp.resType)) + case tp: PolyType => + tp.derivedLambdaType( + resType = addVars(tp.resType)) + case tp: ExprType => + tp.derivedExprType(resType = addVars(tp.resType)) + case _ => + tp + + def addFunctionRefinements(tp: Type): Type = tp match + case tp @ AppliedType(tycon, args) => + if defn.isNonRefinedFunction(tp) then + MethodType.companion( + isContextual = defn.isContextFunctionClass(tycon.classSymbol), + isErased = defn.isErasedFunctionClass(tycon.classSymbol) + )(args.init, addFunctionRefinements(args.last)) + .toFunctionType(isJava = false, alwaysDependent = true) + .showing(i"add function refinement $tp --> $result", capt) + else + tp.derivedAppliedType(tycon, args.map(addFunctionRefinements(_))) + case tp @ RefinedType(core, rname, rinfo) if !defn.isFunctionType(tp) => + tp.derivedRefinedType( + addFunctionRefinements(core), rname, addFunctionRefinements(rinfo)) + case tp: MethodOrPoly => + tp.derivedLambdaType(resType = addFunctionRefinements(tp.resType)) + case tp: ExprType => + tp.derivedExprType(resType = addFunctionRefinements(tp.resType)) + case _ => + tp + + /** Refine a possibly applied class type C where the class has tracked parameters + * x_1: T_1, ..., x_n: T_n to C { val x_1: CV_1 T_1, ..., val x_n: CV_n T_n } + * where CV_1, ..., CV_n are fresh capture sets. + */ + def addCaptureRefinements(tp: Type): Type = tp.stripped match + case _: TypeRef | _: AppliedType if tp.typeSymbol.isClass => + val cls = tp.typeSymbol.asClass + cls.paramGetters.foldLeft(tp) { (core, getter) => + if getter.termRef.isTracked then + val getterType = tp.memberInfo(getter).strippedDealias + RefinedType(core, getter.name, CapturingType(getterType, CaptureSet.Var(), boxed = false)) + .showing(i"add capture refinement $tp --> $result", capt) + else + core + } + case _ => + tp + + def addVars(tp: Type, boxed: Boolean = false): Type = + var tp1 = addInnerVars(tp) + val tp2 = addCaptureRefinements(tp1) + if tp1.canHaveInferredCapture + then CapturingType(tp2, CaptureSet.Var(), boxed) + else tp2 + + if inferred then + val cleanup = new TypeMap: + def apply(t: Type) = t match + case AnnotatedType(parent, annot) if annot.symbol == defn.RetainsAnnot => + apply(parent) + case _ => + mapOver(t) + addVars(addFunctionRefinements(cleanup(tp))) + .showing(i"reinfer $tp --> $result", capt) + else + val addBoxes = new TypeTraverser: + def setBoxed(t: Type) = t match + case AnnotatedType(_, annot) if annot.symbol == defn.RetainsAnnot => + annot.tree.setBoxedCapturing() + case _ => + + def traverse(t: Type) = + t match + case AppliedType(tycon, args) if !defn.isNonRefinedFunction(t) => + args.foreach(setBoxed) + case TypeBounds(lo, hi) => + setBoxed(lo); setBoxed(hi) + case _ => + traverseChildren(t) + end addBoxes + + addBoxes.traverse(tp) + tp + end transformType + + private def interpolator(using Context) = new TypeTraverser: + override def traverse(t: Type) = + t match + case CapturingType(parent, refs: CaptureSet.Var, _) => + if variance < 0 then capt.println(i"solving $t") + refs.solve(variance) + traverse(parent) + case t @ RefinedType(_, nme.apply, rinfo) if defn.isFunctionOrPolyType(t) => + traverse(rinfo) + case tp: TypeVar => + case tp: TypeRef => + traverse(tp.prefix) + case _ => + traverseChildren(t) + + private def interpolateVarsIn(tpt: Tree)(using Context): Unit = + if tpt.isInstanceOf[InferredTypeTree] then + interpolator.traverse(knownType(tpt)) + .showing(i"solved vars in ${knownType(tpt)}", capt) + + private var curEnv: Env = Env(NoSymbol, CaptureSet.empty, false, null) + + private val myCapturedVars: util.EqHashMap[Symbol, CaptureSet] = EqHashMap() + def capturedVars(sym: Symbol)(using Context) = + myCapturedVars.getOrElseUpdate(sym, + if sym.ownersIterator.exists(_.isTerm) then CaptureSet.Var() + else CaptureSet.empty) + + def markFree(sym: Symbol, pos: SrcPos)(using Context): Unit = + if sym.exists then + val ref = sym.termRef + def recur(env: Env): Unit = + if env.isOpen && env.owner != sym.enclosure then + capt.println(i"Mark $sym with cs ${ref.captureSet} free in ${env.owner}") + checkElem(ref, env.captured, pos) + if env.owner.isConstructor then + if env.outer.owner != sym.enclosure then recur(env.outer.outer) + else recur(env.outer) + if ref.isTracked then recur(curEnv) + + def includeCallCaptures(sym: Symbol, pos: SrcPos)(using Context): Unit = + if curEnv.isOpen then + val ownEnclosure = ctx.owner.enclosingMethodOrClass + var targetSet = capturedVars(sym) + if !targetSet.isAlwaysEmpty && sym.enclosure == ownEnclosure then + targetSet = targetSet.filter { + case ref: TermRef => ref.symbol.enclosure != ownEnclosure + case _ => true + } + checkSubset(targetSet, curEnv.captured, pos) + + def includeBoxedCaptures(tp: Type, pos: SrcPos)(using Context): Unit = + if curEnv.isOpen then + val ownEnclosure = ctx.owner.enclosingMethodOrClass + val targetSet = tp.boxedCaptured.filter { + case ref: TermRef => ref.symbol.enclosure != ownEnclosure + case _ => true + } + checkSubset(targetSet, curEnv.captured, pos) + + def assertSub(cs1: CaptureSet, cs2: CaptureSet)(using Context) = + assert(cs1.subCaptures(cs2, frozen = false).isOK, i"$cs1 is not a subset of $cs2") + + def checkElem(elem: CaptureRef, cs: CaptureSet, pos: SrcPos)(using Context) = + val res = elem.singletonCaptureSet.subCaptures(cs, frozen = false) + if !res.isOK then + report.error(i"$elem cannot be referenced here; it is not included in allowed capture set ${res.blocking}", pos) + + def checkSubset(cs1: CaptureSet, cs2: CaptureSet, pos: SrcPos)(using Context) = + val res = cs1.subCaptures(cs2, frozen = false) + if !res.isOK then + report.error(i"references $cs1 are not all included in allowed capture set ${res.blocking}", pos) + + override def recheckClosure(tree: Closure, pt: Type)(using Context): Type = + val cs = capturedVars(tree.meth.symbol) + recheckr.println(i"typing closure $tree with cvs $cs") + super.recheckClosure(tree, pt).capturing(cs) + .showing(i"rechecked $tree, $result", capt) + + override def recheckIdent(tree: Ident)(using Context): Type = + markFree(tree.symbol, tree.srcPos) + if tree.symbol.is(Method) then includeCallCaptures(tree.symbol, tree.srcPos) + super.recheckIdent(tree) + + override def recheckValDef(tree: ValDef, sym: Symbol)(using Context): Unit = + try super.recheckValDef(tree, sym) + finally + if !sym.is(Param) then + // parameters with inferred types belong to anonymous methods. We need to wait + // for more info from the context, so we cannot interpolate. Note that we cannot + // expect to have all necessary info available at the point where the anonymous + // function is compiled since we do not propagate expected types into blocks. + interpolateVarsIn(tree.tpt) + + override def recheckDefDef(tree: DefDef, sym: Symbol)(using Context): Unit = + val saved = curEnv + val localSet = capturedVars(sym) + if !localSet.isAlwaysEmpty then curEnv = Env(sym, localSet, false, curEnv) + try super.recheckDefDef(tree, sym) + finally + interpolateVarsIn(tree.tpt) + curEnv = saved + + override def recheckClassDef(tree: TypeDef, impl: Template, cls: ClassSymbol)(using Context): Type = + for param <- cls.paramGetters do + if param.is(Private) && !param.info.captureSet.isAlwaysEmpty then + report.error( + "Implementation restriction: Class parameter with non-empty capture set must be a `val`", + param.srcPos) + val saved = curEnv + val localSet = capturedVars(cls) + if !localSet.isAlwaysEmpty then curEnv = Env(cls, localSet, false, curEnv) + try super.recheckClassDef(tree, impl, cls) + finally curEnv = saved + + /** First half: Refine the type of a constructor call `new C(t_1, ..., t_n)` + * to C{val x_1: T_1, ..., x_m: T_m} where x_1, ..., x_m are the tracked + * parameters of C and T_1, ..., T_m are the types of the corresponding arguments. + * + * Second half: union of all capture sets of arguments to tracked parameters. + */ + private def addParamArgRefinements(core: Type, argTypes: List[Type], cls: ClassSymbol)(using Context): (Type, CaptureSet) = + cls.paramGetters.lazyZip(argTypes).foldLeft((core, CaptureSet.empty: CaptureSet)) { (acc, refine) => + val (core, allCaptures) = acc + val (getter, argType) = refine + if getter.termRef.isTracked then + (RefinedType(core, getter.name, argType), allCaptures ++ argType.captureSet) + else + (core, allCaptures) + } + + /** Handle an application of method `sym` with type `mt` to arguments of types `argTypes`. + * This means: + * - Instantiate result type with actual arguments + * - If call is to a constructor: + * - remember types of arguments corresponding to tracked + * parameters in refinements. + * - add capture set of instantiated class to capture set of result type. + */ + override def instantiate(mt: MethodType, argTypes: List[Type], sym: Symbol)(using Context): Type = + val ownType = + if mt.isResultDependent then SubstParamsMap(mt, argTypes)(mt.resType) + else mt.resType + if sym.isConstructor then + val cls = sym.owner.asClass + val (refined, cs) = addParamArgRefinements(ownType, argTypes, cls) + refined.capturing(cs ++ capturedVars(cls) ++ capturedVars(sym)) + .showing(i"constr type $mt with $argTypes%, % in $cls = $result", capt) + else ownType + + def recheckByNameArg(tree: Tree, pt: Type)(using Context): Type = + val closureDef(mdef) = tree + val arg = mdef.rhs + val localSet = CaptureSet.Var() + curEnv = Env(mdef.symbol, localSet, isBoxed = false, curEnv) + val result = + try + inContext(ctx.withOwner(mdef.symbol)) { + recheckStart(arg, pt).capturing(localSet) + } + finally curEnv = curEnv.outer + recheckFinish(result, arg, pt) + + override def recheckApply(tree: Apply, pt: Type)(using Context): Type = + if tree.symbol == defn.cbnArg then + recheckByNameArg(tree.args(0), pt) + else + includeCallCaptures(tree.symbol, tree.srcPos) + super.recheckApply(tree, pt) + + override def recheck(tree: Tree, pt: Type = WildcardType)(using Context): Type = + val res = super.recheck(tree, pt) + if tree.isTerm then + includeBoxedCaptures(res, tree.srcPos) + res + + override def checkUnit(unit: CompilationUnit)(using Context): Unit = + withCaptureSetsExplained { + super.checkUnit(unit) + PostRefinerCheck.traverse(unit.tpdTree) + if ctx.settings.YccDebug.value then + show(unit.tpdTree) // this dows not print tree, but makes its variables visible for dependency printing + } + + def checkNotGlobal(tree: Tree, allArgs: Tree*)(using Context): Unit = + if disallowGlobal then + tree match + case LambdaTypeTree(_, restpt) => + checkNotGlobal(restpt, allArgs*) + case _ => + for ref <- knownType(tree).captureSet.elems do + val isGlobal = ref match + case ref: TermRef => + ref.isRootCapability || ref.prefix != NoPrefix && ref.symbol.hasAnnotation(defn.AbilityAnnot) + case _ => false + val what = if ref.isRootCapability then "universal" else "global" + if isGlobal then + val notAllowed = i" is not allowed to capture the $what capability $ref" + def msg = tree match + case tree: InferredTypeTree => + i"""inferred type argument ${knownType(tree)}$notAllowed + | + |The inferred arguments are: [${allArgs.map(knownType)}%, %]""" + case _ => s"type argument$notAllowed" + report.error(msg, tree.srcPos) + + object PostRefinerCheck extends TreeTraverser: + def traverse(tree: Tree)(using Context) = + tree match + case _: InferredTypeTree => + case tree: TypeTree if !tree.span.isZeroExtent => + knownType(tree).foreachPart( + checkWellformedPost(_, tree.srcPos)) + knownType(tree).foreachPart { + case AnnotatedType(_, annot) => + checkWellformedPost(annot.tree) + case _ => + } + case tree1 @ TypeApply(fn, args) if disallowGlobal => + for arg <- args do + //println(i"checking $arg in $tree: ${knownType(tree).captureSet}") + checkNotGlobal(arg, args*) + case t: ValOrDefDef if t.tpt.isInstanceOf[InferredTypeTree] => + val sym = t.symbol + val isLocal = + sym.ownersIterator.exists(_.isTerm) + || sym.accessBoundary(defn.RootClass).isContainedIn(sym.topLevelClass) + + // The following classes of definitions need explicit capture types ... + if !isLocal // ... since external capture types are not inferred + || sym.owner.is(Trait) // ... since we do OverridingPairs checking before capture inference + || sym.allOverriddenSymbols.nonEmpty // ... since we do override checking before capture inference + then + val inferred = knownType(t.tpt) + def checkPure(tp: Type) = tp match + case CapturingType(_, refs, _) if !refs.elems.isEmpty => + val resultStr = if t.isInstanceOf[DefDef] then " result" else "" + report.error( + em"""Non-local $sym cannot have an inferred$resultStr type + |$inferred + |with non-empty capture set $refs. + |The type needs to be declared explicitly.""", t.srcPos) + case _ => + inferred.foreachPart(checkPure, StopAt.Static) + case _ => + traverseChildren(tree) + + def postRefinerCheck(tree: tpd.Tree)(using Context): Unit = + PostRefinerCheck.traverse(tree) + + end CaptureChecker +end CheckCaptures diff --git a/compiler/src/dotty/tools/dotc/typer/Checking.scala b/compiler/src/dotty/tools/dotc/typer/Checking.scala index be38221ef167..c46c6d7c06cd 100644 --- a/compiler/src/dotty/tools/dotc/typer/Checking.scala +++ b/compiler/src/dotty/tools/dotc/typer/Checking.scala @@ -76,9 +76,8 @@ object Checking { } for (arg, which, bound) <- TypeOps.boundsViolations(args, boundss, instantiate, app) do report.error( - showInferred(DoesNotConformToBound(arg.tpe, which, bound), - app, tpt), - arg.srcPos.focus) + showInferred(DoesNotConformToBound(arg.tpe, which, bound), app, tpt), + arg.srcPos.focus) /** Check that type arguments `args` conform to corresponding bounds in `tl` * Note: This does not check the bounds of AppliedTypeTrees. These @@ -312,6 +311,7 @@ object Checking { case AndType(tp1, tp2) => isInteresting(tp1) || isInteresting(tp2) case OrType(tp1, tp2) => isInteresting(tp1) && isInteresting(tp2) case _: RefinedOrRecType | _: AppliedType => true + case tp: AnnotatedType => isInteresting(tp.parent) case _ => false } diff --git a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala index 17df9c93f9a9..108d07cb43f5 100644 --- a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala +++ b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala @@ -14,6 +14,7 @@ import Decorators._ import config.Printers.{gadts, typr, debug} import annotation.tailrec import reporting._ +import cc.{CapturingType, derivedCapturingType} import collection.mutable import scala.annotation.internal.sharable @@ -130,8 +131,8 @@ object Inferencing { couldInstantiateTypeVar(parent, applied) case tp: AndOrType => couldInstantiateTypeVar(tp.tp1, applied) || couldInstantiateTypeVar(tp.tp2, applied) - case AnnotatedType(tp, _) => - couldInstantiateTypeVar(tp, applied) + case tp: AnnotatedType => + couldInstantiateTypeVar(tp.parent, applied) case _ => false @@ -538,6 +539,7 @@ object Inferencing { case tp: RefinedType => tp.derivedRefinedType(captureWildcards(tp.parent), tp.refinedName, tp.refinedInfo) case tp: RecType => tp.derivedRecType(captureWildcards(tp.parent)) case tp: LazyRef => captureWildcards(tp.ref) + case CapturingType(parent, refs, _) => tp.derivedCapturingType(captureWildcards(parent), refs) case tp: AnnotatedType => tp.derivedAnnotatedType(captureWildcards(tp.parent), tp.annot) case _ => tp } @@ -726,6 +728,7 @@ trait Inferencing { this: Typer => if !argType.isSingleton then argType = SkolemType(argType) argType <:< tvar case _ => + () // scala-meta complains if this is missing, but I could not mimimize further end constrainIfDependentParamRef } @@ -740,4 +743,3 @@ trait Inferencing { this: Typer => enum IfBottom: case ok, fail, flip - diff --git a/compiler/src/dotty/tools/dotc/typer/RefChecks.scala b/compiler/src/dotty/tools/dotc/typer/RefChecks.scala index 9ababe3e5f07..c1ec71e1d0f8 100644 --- a/compiler/src/dotty/tools/dotc/typer/RefChecks.scala +++ b/compiler/src/dotty/tools/dotc/typer/RefChecks.scala @@ -269,7 +269,7 @@ object RefChecks { * TODO This still needs to be cleaned up; the current version is a straight port of what was there * before, but it looks too complicated and method bodies are far too large. */ - private def checkAllOverrides(clazz: ClassSymbol)(using Context): Unit = { + def checkAllOverrides(clazz: ClassSymbol)(using Context): Unit = { val self = clazz.thisType val upwardsSelf = upwardsThisType(clazz) var hasErrors = false diff --git a/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala b/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala index 7b67d828ddcb..0202f5b9d025 100644 --- a/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala +++ b/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala @@ -15,6 +15,7 @@ import ProtoTypes._ import collection.mutable import reporting._ import Checking.{checkNoPrivateLeaks, checkNoWildcard} +import cc.CaptureSet trait TypeAssigner { import tpd.* @@ -191,6 +192,14 @@ trait TypeAssigner { if tpe.isError then tpe else errorType(ex"$whatCanNot be accessed as a member of $pre$where.$whyNot", pos) + def processAppliedType(tree: untpd.Tree, tp: Type)(using Context): Type = tp match + case AppliedType(tycon, args) => + val constr = tycon.typeSymbol + if constr == defn.andType then AndType(args(0), args(1)) + else if constr == defn.orType then OrType(args(0), args(1), soft = false) + else tp + case _ => tp + /** Type assignment method. Each method takes as parameters * - an untpd.Tree to which it assigns a type, * - typed child trees it needs to access to cpmpute that type, @@ -288,8 +297,12 @@ trait TypeAssigner { val ownType = fn.tpe.widen match { case fntpe: MethodType => if (sameLength(fntpe.paramInfos, args) || ctx.phase.prev.relaxedTyping) - if (fntpe.isResultDependent) safeSubstParams(fntpe.resultType, fntpe.paramRefs, args.tpes) - else fntpe.resultType + if fntpe.isCaptureDependent then + fntpe.resultType.substParams(fntpe, args.tpes) + else if fntpe.isResultDependent then + safeSubstParams(fntpe.resultType, fntpe.paramRefs, args.tpes) + else + fntpe.resultType else errorType(i"wrong number of arguments at ${ctx.phase.prev} for $fntpe: ${fn.tpe}, expected: ${fntpe.paramInfos.length}, found: ${args.length}", tree.srcPos) case t => @@ -461,11 +474,10 @@ trait TypeAssigner { assert(!hasNamedArg(args) || ctx.reporter.errorsReported, tree) val tparams = tycon.tpe.typeParams val ownType = - if (sameLength(tparams, args)) - if (tycon.symbol == defn.andType) AndType(args(0).tpe, args(1).tpe) - else if (tycon.symbol == defn.orType) OrType(args(0).tpe, args(1).tpe, soft = false) - else tycon.tpe.appliedTo(args.tpes) - else wrongNumberOfTypeArgs(tycon.tpe, tparams, args, tree.srcPos) + if !sameLength(tparams, args) then + wrongNumberOfTypeArgs(tycon.tpe, tparams, args, tree.srcPos) + else + processAppliedType(tree, tycon.tpe.appliedTo(args.tpes)) tree.withType(ownType) } diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 02f3f6b4f164..bb871654ea5e 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -49,6 +49,7 @@ import transform.TypeUtils._ import reporting._ import Nullables._ import NullOpsDecorator._ +import cc.CheckCaptures import config.Config import scala.annotation.constructorOnly @@ -1173,7 +1174,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer case _ => mapOver(t) } - val pt1 = pt.stripTypeVar.dealias.normalized + val pt1 = pt.strippedDealias.normalized if (pt1 ne pt1.dropDependentRefinement) && defn.isContextFunctionType(pt1.nonPrivateMember(nme.apply).info.finalResultType) then @@ -2596,6 +2597,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer registerNowarn(annot1, tree) val arg1 = typed(tree.arg, pt) if (ctx.mode is Mode.Type) { + if annot1.symbol.maybeOwner == defn.RetainsAnnot then + CheckCaptures.checkWellformed(annot1) if arg1.isType then assignType(cpy.Annotated(tree)(arg1, annot1), arg1, annot1) else diff --git a/compiler/src/dotty/tools/dotc/util/SimpleIdentitySet.scala b/compiler/src/dotty/tools/dotc/util/SimpleIdentitySet.scala index 45ee3652fe16..1fac0dac0913 100644 --- a/compiler/src/dotty/tools/dotc/util/SimpleIdentitySet.scala +++ b/compiler/src/dotty/tools/dotc/util/SimpleIdentitySet.scala @@ -15,6 +15,7 @@ abstract class SimpleIdentitySet[+Elem <: AnyRef] { def map[B <: AnyRef](f: Elem => B): SimpleIdentitySet[B] def /: [A, E >: Elem <: AnyRef](z: A)(f: (A, E) => A): A def toList: List[Elem] + def iterator: Iterator[Elem] final def isEmpty: Boolean = size == 0 @@ -59,6 +60,7 @@ object SimpleIdentitySet { def map[B <: AnyRef](f: Nothing => B): SimpleIdentitySet[B] = empty def /: [A, E <: AnyRef](z: A)(f: (A, E) => A): A = z def toList = Nil + def iterator = Iterator.empty } private class Set1[+Elem <: AnyRef](x0: AnyRef) extends SimpleIdentitySet[Elem] { @@ -76,6 +78,7 @@ object SimpleIdentitySet { def /: [A, E >: Elem <: AnyRef](z: A)(f: (A, E) => A): A = f(z, x0.asInstanceOf[E]) def toList = x0.asInstanceOf[Elem] :: Nil + def iterator = Iterator.single(x0.asInstanceOf[Elem]) } private class Set2[+Elem <: AnyRef](x0: AnyRef, x1: AnyRef) extends SimpleIdentitySet[Elem] { @@ -95,6 +98,10 @@ object SimpleIdentitySet { def /: [A, E >: Elem <: AnyRef](z: A)(f: (A, E) => A): A = f(f(z, x0.asInstanceOf[E]), x1.asInstanceOf[E]) def toList = x0.asInstanceOf[Elem] :: x1.asInstanceOf[Elem] :: Nil + def iterator = Iterator.tabulate(2) { + case 0 => x0.asInstanceOf[Elem] + case 1 => x1.asInstanceOf[Elem] + } } private class Set3[+Elem <: AnyRef](x0: AnyRef, x1: AnyRef, x2: AnyRef) extends SimpleIdentitySet[Elem] { @@ -125,6 +132,11 @@ object SimpleIdentitySet { def /: [A, E >: Elem <: AnyRef](z: A)(f: (A, E) => A): A = f(f(f(z, x0.asInstanceOf[E]), x1.asInstanceOf[E]), x2.asInstanceOf[E]) def toList = x0.asInstanceOf[Elem] :: x1.asInstanceOf[Elem] :: x2.asInstanceOf[Elem] :: Nil + def iterator = Iterator.tabulate(3) { + case 0 => x0.asInstanceOf[Elem] + case 1 => x1.asInstanceOf[Elem] + case 2 => x2.asInstanceOf[Elem] + } } private class SetN[+Elem <: AnyRef](val xs: Array[AnyRef]) extends SimpleIdentitySet[Elem] { @@ -173,6 +185,7 @@ object SimpleIdentitySet { foreach(buf += _) buf.toList } + def iterator = xs.iterator.asInstanceOf[Iterator[Elem]] override def ++ [E >: Elem <: AnyRef](that: SimpleIdentitySet[E]): SimpleIdentitySet[E] = that match { case that: SetN[?] => diff --git a/compiler/test/dotty/tools/dotc/CompilationTests.scala b/compiler/test/dotty/tools/dotc/CompilationTests.scala index 7bb546150cd8..4f128f6444ab 100644 --- a/compiler/test/dotty/tools/dotc/CompilationTests.scala +++ b/compiler/test/dotty/tools/dotc/CompilationTests.scala @@ -39,6 +39,7 @@ class CompilationTests { compileFilesInDir("tests/pos-special/isInstanceOf", allowDeepSubtypes.and("-Xfatal-warnings")), compileFilesInDir("tests/new", defaultOptions.and("-source", "3.1")), // just to see whether 3.1 works compileFilesInDir("tests/pos-scala2", scala2CompatMode), + compileFilesInDir("tests/pos-custom-args/captures", defaultOptions.and("-Ycc")), compileFilesInDir("tests/pos-custom-args/erased", defaultOptions.and("-language:experimental.erasedDefinitions")), compileFilesInDir("tests/pos", defaultOptions.and("-Ysafe-init")), compileFilesInDir("tests/pos-deep-subtype", allowDeepSubtypes), @@ -139,6 +140,7 @@ class CompilationTests { compileFilesInDir("tests/neg-custom-args/allow-deep-subtypes", allowDeepSubtypes), compileFilesInDir("tests/neg-custom-args/explicit-nulls", defaultOptions.and("-Yexplicit-nulls")), compileFilesInDir("tests/neg-custom-args/no-experimental", defaultOptions.and("-Yno-experimental")), + compileFilesInDir("tests/neg-custom-args/captures", defaultOptions.and("-Ycc")), compileDir("tests/neg-custom-args/impl-conv", defaultOptions.and("-Xfatal-warnings", "-feature")), compileDir("tests/neg-custom-args/i13946", defaultOptions.and("-Xfatal-warnings", "-feature")), compileFile("tests/neg-custom-args/implicit-conversions.scala", defaultOptions.and("-Xfatal-warnings", "-feature")), @@ -182,6 +184,7 @@ class CompilationTests { compileFile("tests/neg-custom-args/deptypes.scala", defaultOptions.and("-language:experimental.dependent")), compileFile("tests/neg-custom-args/matchable.scala", defaultOptions.and("-Xfatal-warnings", "-source", "future")), compileFile("tests/neg-custom-args/i7314.scala", defaultOptions.and("-Xfatal-warnings", "-source", "future")), + compileFile("tests/neg-custom-args/capt-wf.scala", defaultOptions.and("-Ycc", "-Xfatal-warnings")), compileFile("tests/neg-custom-args/feature-shadowing.scala", defaultOptions.and("-Xfatal-warnings", "-feature")), compileDir("tests/neg-custom-args/hidden-type-errors", defaultOptions.and("-explain")), compileFile("tests/neg-custom-args/i13026.scala", defaultOptions.and("-print-lines")), diff --git a/library/src-bootstrapped/scala/Retains.scala b/library/src-bootstrapped/scala/Retains.scala new file mode 100644 index 000000000000..f3bfa282a012 --- /dev/null +++ b/library/src-bootstrapped/scala/Retains.scala @@ -0,0 +1,6 @@ +package scala + +/** An annotation that indicates capture + */ +class retains(xs: Any*) extends annotation.StaticAnnotation + diff --git a/library/src-bootstrapped/scala/annotation/ability.scala b/library/src-bootstrapped/scala/annotation/ability.scala new file mode 100644 index 000000000000..8b327a2f8b02 --- /dev/null +++ b/library/src-bootstrapped/scala/annotation/ability.scala @@ -0,0 +1,9 @@ +package scala.annotation + +/** An annotation inidcating that a val should be tracked as its own ability. + * Example: + * + * @ability erased val canThrow: * = ??? + * ^^^ rename to capability + */ +class ability extends StaticAnnotation \ No newline at end of file diff --git a/library/src/scala/runtime/stdLibPatches/Predef.scala b/library/src/scala/runtime/stdLibPatches/Predef.scala index 13dfc77ac60b..387096ab55c5 100644 --- a/library/src/scala/runtime/stdLibPatches/Predef.scala +++ b/library/src/scala/runtime/stdLibPatches/Predef.scala @@ -47,4 +47,5 @@ object Predef: */ extension [T](x: T | Null) inline def nn: x.type & T = scala.runtime.Scala3RunTime.nn(x) + end Predef diff --git a/tests/disabled/neg-custom-args/captures/capt-wf.scala b/tests/disabled/neg-custom-args/captures/capt-wf.scala new file mode 100644 index 000000000000..54fe545f443b --- /dev/null +++ b/tests/disabled/neg-custom-args/captures/capt-wf.scala @@ -0,0 +1,19 @@ +// No longer valid +class C +type Cap = C @retains(*) +type Top = Any @retains(*) + +type T = (x: Cap) => List[String @retains(x)] => Unit // error +val x: (x: Cap) => Array[String @retains(x)] = ??? // error +val y = x + +def test: Unit = + def f(x: Cap) = // ok + val g = (xs: List[String @retains(x)]) => () + g + def f2(x: Cap)(xs: List[String @retains(x)]) = () + val x = f // error + val x2 = f2 // error + val y = f(C()) // ok + val y2 = f2(C()) // ok + () diff --git a/tests/disabled/neg-custom-args/captures/try2.check b/tests/disabled/neg-custom-args/captures/try2.check new file mode 100644 index 000000000000..c7b20d0f7c5e --- /dev/null +++ b/tests/disabled/neg-custom-args/captures/try2.check @@ -0,0 +1,38 @@ +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/try2.scala:31:32 ----------------------------------------- +31 | (x: CanThrow[Exception]) => () => raise(new Exception)(using x) // error + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | Found: {x} () => Nothing + | Required: () => Nothing + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/try2.scala:45:2 ------------------------------------------ +45 | yy // error + | ^^ + | Found: (yy : List[(xx : (() => Int) retains canThrow)]) + | Required: List[() => Int] + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/try2.scala:52:2 ------------------------------------------ +47 |val global = handle { +48 | (x: CanThrow[Exception]) => +49 | () => +50 | raise(new Exception)(using x) +51 | 22 +52 |} { // error + | ^ + | Found: (() => Int) retains canThrow + | Required: () => Int +53 | (ex: Exception) => () => 22 +54 |} + +longer explanation available when compiling with `-explain` +-- Error: tests/neg-custom-args/captures/try2.scala:24:28 -------------------------------------------------------------- +24 | val a = handle[Exception, CanThrow[Exception]] { // error + | ^^^^^^^^^^^^^^^^^^^ + | type argument is not allowed to capture the global capability (canThrow : *) +-- Error: tests/neg-custom-args/captures/try2.scala:36:11 -------------------------------------------------------------- +36 | val xx = handle { // error + | ^^^^^^ + |inferred type argument ((() => Int) retains canThrow) is not allowed to capture the global capability (canThrow : *) + | + |The inferred arguments are: [Exception, ((() => Int) retains canThrow)] diff --git a/tests/disabled/neg-custom-args/captures/try2.scala b/tests/disabled/neg-custom-args/captures/try2.scala new file mode 100644 index 000000000000..dd3cc890a197 --- /dev/null +++ b/tests/disabled/neg-custom-args/captures/try2.scala @@ -0,0 +1,55 @@ +// Retains syntax for classes not (yet?) supported +import language.experimental.erasedDefinitions +import annotation.ability + +@ability erased val canThrow: * = ??? + +class CanThrow[E <: Exception] extends Retains[canThrow.type] +type Top = Any @retains(*) + +infix type throws[R, E <: Exception] = (erased CanThrow[E]) ?=> R + +class Fail extends Exception + +def raise[E <: Exception](e: E): Nothing throws E = throw e + +def foo(x: Boolean): Int throws Fail = + if x then 1 else raise(Fail()) + +def handle[E <: Exception, R <: Top](op: CanThrow[E] => R)(handler: E => R): R = + val x: CanThrow[E] = ??? + try op(x) + catch case ex: E => handler(ex) + +def test: List[() => Int] = + val a = handle[Exception, CanThrow[Exception]] { // error + (x: CanThrow[Exception]) => x + }{ + (ex: Exception) => ??? + } + + val b = handle[Exception, () => Nothing] { + (x: CanThrow[Exception]) => () => raise(new Exception)(using x) // error + } { + (ex: Exception) => ??? + } + + val xx = handle { // error + (x: CanThrow[Exception]) => + () => + raise(new Exception)(using x) + 22 + } { + (ex: Exception) => () => 22 + } + val yy = xx :: Nil + yy // error + +val global = handle { + (x: CanThrow[Exception]) => + () => + raise(new Exception)(using x) + 22 +} { // error + (ex: Exception) => () => 22 +} diff --git a/tests/disabled/pos/lazylist.scala b/tests/disabled/pos/lazylist.scala new file mode 100644 index 000000000000..be628113d2d8 --- /dev/null +++ b/tests/disabled/pos/lazylist.scala @@ -0,0 +1,51 @@ +package lazylists + +abstract class LazyList[+T]: + this: ({*} LazyList[T]) => + + def isEmpty: Boolean + def head: T + def tail: LazyList[T] + + def map[U](f: {*} T => U): {f, this} LazyList[U] = + if isEmpty then LazyNil + else LazyCons(f(head), () => tail.map(f)) + + def concat[U >: T](that: {*} LazyList[U]): {this, that} LazyList[U] + +// def flatMap[U](f: {*} T => LazyList[U]): {f, this} LazyList[U] + +class LazyCons[+T](val x: T, val xs: {*} () => {*} LazyList[T]) extends LazyList[T]: + def isEmpty = false + def head = x + def tail: {*} LazyList[T] = xs() + def concat[U >: T](that: {*} LazyList[U]): {this, that} LazyList[U] = + LazyCons(x, () => xs().concat(that)) +// def flatMap[U](f: {*} T => LazyList[U]): {f, this} LazyList[U] = +// f(x).concat(xs().flatMap(f)) + +object LazyNil extends LazyList[Nothing]: + def isEmpty = true + def head = ??? + def tail = ??? + def concat[U](that: {*} LazyList[U]): {that} LazyList[U] = that +// def flatMap[U](f: {*} Nothing => LazyList[U]): LazyList[U] = LazyNil + +def map[A, B](xs: {*} LazyList[A], f: {*} A => B): {f, xs} LazyList[B] = + xs.map(f) + +class CC +type Cap = {*} CC + +def test(cap1: Cap, cap2: Cap, cap3: Cap) = + def f[T](x: LazyList[T]): LazyList[T] = if cap1 == cap1 then x else LazyNil + def g(x: Int) = if cap2 == cap2 then x else 0 + def h(x: Int) = if cap3 == cap3 then x else 0 + val ref1 = LazyCons(1, () => f(LazyNil)) + val ref1c: {cap1} LazyList[Int] = ref1 + val ref2 = map(ref1, g) + val ref2c: {cap2, ref1} LazyList[Int] = ref2 + val ref3 = ref1.map(g) + val ref3c: {cap2, ref1} LazyList[Int] = ref3 + val ref4 = (if cap1 == cap2 then ref1 else ref2).map(h) + val ref4c: {cap1, cap2, cap3} LazyList[Int] = ref4 \ No newline at end of file diff --git a/tests/neg/i9325.scala b/tests/neg-custom-args/allow-deep-subtypes/i9325.scala similarity index 100% rename from tests/neg/i9325.scala rename to tests/neg-custom-args/allow-deep-subtypes/i9325.scala diff --git a/tests/neg-custom-args/capt-wf.scala b/tests/neg-custom-args/capt-wf.scala new file mode 100644 index 000000000000..dc4d6a0d4bff --- /dev/null +++ b/tests/neg-custom-args/capt-wf.scala @@ -0,0 +1,35 @@ +class C +type Cap = {*} C + +object foo + +def test(c: Cap, other: String): Unit = + val x1: {*} C = ??? // OK + val x2: {other} C = ??? // error: cs is empty + val s1 = () => "abc" + val x3: {s1} C = ??? // error: cs is empty + val x3a: () => String = s1 + val s2 = () => if x1 == null then "" else "abc" + val x4: {s2} C = ??? // OK + val x5: {c, c} C = ??? // error: redundant + val x6: {c} {c} C = ??? // error: redundant + val x7: {c} Cap = ??? // error: redundant + val x8: {*} {c} C = ??? // OK + val x9: {c, *} C = ??? // error: redundant + val x10: {*, c} C = ??? // error: redundant + + def even(n: Int): Boolean = if n == 0 then true else odd(n - 1) + def odd(n: Int): Boolean = if n == 1 then true else even(n - 1) + val e1 = even + val o1 = odd + + val y1: {e1} String = ??? // error cs is empty + val y2: {o1} String = ??? // error cs is empty + + lazy val ev: (Int => Boolean) = (n: Int) => + lazy val od: (Int => Boolean) = (n: Int) => + if n == 1 then true else ev(n - 1) + if n == 0 then true else od(n - 1) + val y3: {ev} String = ??? // error cs is empty + + () \ No newline at end of file diff --git a/tests/neg-custom-args/captures/bounded.scala b/tests/neg-custom-args/captures/bounded.scala new file mode 100644 index 000000000000..dc2621e95a65 --- /dev/null +++ b/tests/neg-custom-args/captures/bounded.scala @@ -0,0 +1,14 @@ +class CC +type Cap = {*} CC + +def test(c: Cap) = + class B[X <: {c} Object](x: X): + def elem = x + def lateElem = () => x + + def f(x: Int): Int = if c == c then x else 0 + val b = new B(f) + val r1 = b.elem + val r1c: {c} Int => Int = r1 + val r2 = b.lateElem + val r2c: () => {c} Int => Int = r2 // error \ No newline at end of file diff --git a/tests/neg-custom-args/captures/boxmap.check b/tests/neg-custom-args/captures/boxmap.check new file mode 100644 index 000000000000..406077077af5 --- /dev/null +++ b/tests/neg-custom-args/captures/boxmap.check @@ -0,0 +1,7 @@ +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/boxmap.scala:14:2 ---------------------------------------- +14 | () => b[Box[B]]((x: A) => box(f(x))) // error + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | Found: {f} () => ? Box[B] + | Required: () => Box[B] + +longer explanation available when compiling with `-explain` diff --git a/tests/neg-custom-args/captures/boxmap.scala b/tests/neg-custom-args/captures/boxmap.scala new file mode 100644 index 000000000000..e335320ef9d4 --- /dev/null +++ b/tests/neg-custom-args/captures/boxmap.scala @@ -0,0 +1,14 @@ +type Top = Any @retains(*) + +infix type ==> [A, B] = (A => B) @retains(*) + +type Box[+T <: Top] = ([K <: Top] => (T ==> K) => K) + +def box[T <: Top](x: T): Box[T] = + [K <: Top] => (k: T ==> K) => k(x) + +def map[A <: Top, B <: Top](b: Box[A])(f: A ==> B): Box[B] = + b[Box[B]]((x: A) => box(f(x))) + +def lazymap[A <: Top, B <: Top](b: Box[A])(f: A ==> B): () => Box[B] = + () => b[Box[B]]((x: A) => box(f(x))) // error diff --git a/tests/neg-custom-args/captures/byname.scala b/tests/neg-custom-args/captures/byname.scala new file mode 100644 index 000000000000..526cdc50952f --- /dev/null +++ b/tests/neg-custom-args/captures/byname.scala @@ -0,0 +1,10 @@ +class CC +type Cap = {*} CC + +def test(cap1: Cap, cap2: Cap) = + def f() = if cap1 == cap1 then g else g + def g(x: Int) = if cap2 == cap2 then 1 else x + def h(ff: => {cap2} Int => Int) = ff + h(f()) // error + + diff --git a/tests/neg-custom-args/captures/capt-box-env.scala b/tests/neg-custom-args/captures/capt-box-env.scala new file mode 100644 index 000000000000..e9743054076e --- /dev/null +++ b/tests/neg-custom-args/captures/capt-box-env.scala @@ -0,0 +1,12 @@ +class C +type Cap = {*} C + +class Pair[+A, +B](x: A, y: B): + def fst: A = x + def snd: B = y + +def test(c: Cap) = + def f(x: Cap): Unit = if c == x then () + val p = Pair(f, f) + val g = () => p.fst == p.snd + val gc: () => Boolean = g // error diff --git a/tests/neg-custom-args/captures/capt-box.scala b/tests/neg-custom-args/captures/capt-box.scala new file mode 100644 index 000000000000..317fc064ec0b --- /dev/null +++ b/tests/neg-custom-args/captures/capt-box.scala @@ -0,0 +1,13 @@ +//import scala.retains +class C +type Cap = {*} C + +def test(x: Cap) = + + def foo(y: Cap) = if x == y then println() + + val x1 = foo + + val x2 = identity(x1) + + val x3: Cap => Unit = x2 // error \ No newline at end of file diff --git a/tests/neg-custom-args/captures/capt-depfun.scala b/tests/neg-custom-args/captures/capt-depfun.scala new file mode 100644 index 000000000000..6b0beb92b313 --- /dev/null +++ b/tests/neg-custom-args/captures/capt-depfun.scala @@ -0,0 +1,7 @@ +class C +type Cap = C @retains(*) + +def f(y: Cap, z: Cap) = + def g(): C @retains(y, z) = ??? + val ac: ((x: Cap) => String @retains(x) => String @retains(x)) = ??? + val dc: (({y, z} String) => {y, z} String) = ac(g()) // error diff --git a/tests/neg-custom-args/captures/capt-depfun2.scala b/tests/neg-custom-args/captures/capt-depfun2.scala new file mode 100644 index 000000000000..874d753b048d --- /dev/null +++ b/tests/neg-custom-args/captures/capt-depfun2.scala @@ -0,0 +1,10 @@ +class C +type Cap = C @retains(*) + +def f(y: Cap, z: Cap) = + def g(): C @retains(y, z) = ??? + val ac: ((x: Cap) => Array[String @retains(x)]) = ??? + val dc = ac(g()) // error: Needs explicit type Array[? >: String <: {y, z} String] + // This is a shortcoming of rechecking since the originally inferred + // type is `Array[String]` and the actual type after rechecking + // cannot be expressed as `Array[C String]` for any capture set C \ No newline at end of file diff --git a/tests/neg-custom-args/captures/capt-env.scala b/tests/neg-custom-args/captures/capt-env.scala new file mode 100644 index 000000000000..84b4b57a7930 --- /dev/null +++ b/tests/neg-custom-args/captures/capt-env.scala @@ -0,0 +1,13 @@ +class C +type Cap = {*} C + +class Pair[+A, +B](x: A, y: B): + def fst: A = x + def snd: B = y + +def test(c: Cap) = + def f(x: Cap): Unit = if c == x then () + val p = Pair(f, f) + val g = () => p.fst == p.snd + val gc: () => Boolean = g // error + diff --git a/tests/neg-custom-args/captures/capt-test.scala b/tests/neg-custom-args/captures/capt-test.scala new file mode 100644 index 000000000000..0c536a280f5c --- /dev/null +++ b/tests/neg-custom-args/captures/capt-test.scala @@ -0,0 +1,26 @@ +import language.experimental.erasedDefinitions + +class CT[E <: Exception] +type CanThrow[E <: Exception] = CT[E] @retains(*) +type Top = Any @retains(*) + +infix type throws[R, E <: Exception] = (erased CanThrow[E]) ?=> R + +class Fail extends Exception + +def raise[E <: Exception](e: E): Nothing throws E = throw e + +def foo(x: Boolean): Int throws Fail = + if x then 1 else raise(Fail()) + +def handle[E <: Exception, R <: Top](op: (CanThrow[E]) => R)(handler: E => R): R = + val x: CanThrow[E] = ??? + try op(x) + catch case ex: E => handler(ex) + +def test: Unit = + val b = handle[Exception, () => Nothing] { // error + (x: CanThrow[Exception]) => () => raise(new Exception)(using x) + } { + (ex: Exception) => ??? + } diff --git a/tests/neg-custom-args/captures/capt-wf-typer.scala b/tests/neg-custom-args/captures/capt-wf-typer.scala new file mode 100644 index 000000000000..5120e2b288d5 --- /dev/null +++ b/tests/neg-custom-args/captures/capt-wf-typer.scala @@ -0,0 +1,10 @@ +class C +type Cap = {*} C + +object foo + +def test(c: Cap, other: String): Unit = + val x7: {c} String = ??? // OK + val x8: String @retains(x7 + x7) = ??? // error + val x9: String @retains(foo) = ??? // error + () \ No newline at end of file diff --git a/tests/neg-custom-args/captures/capt1.check b/tests/neg-custom-args/captures/capt1.check new file mode 100644 index 000000000000..ce7c4833bf9c --- /dev/null +++ b/tests/neg-custom-args/captures/capt1.check @@ -0,0 +1,46 @@ +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:3:2 ------------------------------------------ +3 | () => if x == null then y else y // error + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | Found: {x} () => ? C + | Required: () => C + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:6:2 ------------------------------------------ +6 | () => if x == null then y else y // error + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | Found: {x} () => ? C + | Required: Matchable + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:13:2 ----------------------------------------- +13 | def f(y: Int) = if x == null then y else y // error + | ^ + | Found: {x} Int => Int + | Required: Matchable +14 | f + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:20:2 ----------------------------------------- +20 | class F(y: Int) extends A: // error + | ^ + | Found: {x} A + | Required: A +21 | def m() = if x == null then y else y +22 | F(22) + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:25:2 ----------------------------------------- +25 | new A: // error + | ^ + | Found: {x} A + | Required: A +26 | def m() = if x == null then y else y + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:31:24 ---------------------------------------- +31 | val z2 = h[() => Cap](() => x)(() => C()) // error + | ^^^^^^^ + | Found: {x} () => ? Cap + | Required: () => Cap + +longer explanation available when compiling with `-explain` diff --git a/tests/neg-custom-args/captures/capt1.scala b/tests/neg-custom-args/captures/capt1.scala new file mode 100644 index 000000000000..4da49c5f4f1e --- /dev/null +++ b/tests/neg-custom-args/captures/capt1.scala @@ -0,0 +1,34 @@ +class C +def f(x: C @retains(*), y: C): () => C = + () => if x == null then y else y // error + +def g(x: C @retains(*), y: C): Matchable = + () => if x == null then y else y // error + +def h1(x: C @retains(*), y: C): Any = + def f() = if x == null then y else y + () => f() // ok + +def h2(x: C @retains(*)): Matchable = + def f(y: Int) = if x == null then y else y // error + f + +class A +type Cap = C @retains(*) + +def h3(x: Cap): A = + class F(y: Int) extends A: // error + def m() = if x == null then y else y + F(22) + +def h4(x: Cap, y: Int): A = + new A: // error + def m() = if x == null then y else y + +def foo() = + val x: C @retains(*) = ??? + def h[X](a: X)(b: X) = a + val z2 = h[() => Cap](() => x)(() => C()) // error + val z3 = h[(() => Cap) @retains(x)](() => x)(() => C()) // ok + val z4 = h[(() => Cap) @retains(x)](() => x)(() => C()) // what was inferred for z3 + diff --git a/tests/neg-custom-args/captures/capt2.scala b/tests/neg-custom-args/captures/capt2.scala new file mode 100644 index 000000000000..1eee53463f6d --- /dev/null +++ b/tests/neg-custom-args/captures/capt2.scala @@ -0,0 +1,9 @@ +//import scala.retains +class C +type Cap = {*} C + +def f1(c: Cap): (() => {c} C) = () => c // error, but would be OK under capture abbreciations for funciton types +def f2(c: Cap): ({c} () => C) = () => c // error + +def h5(x: Cap): () => C = + f1(x) // error diff --git a/tests/neg-custom-args/captures/capt3.scala b/tests/neg-custom-args/captures/capt3.scala new file mode 100644 index 000000000000..80b937276f73 --- /dev/null +++ b/tests/neg-custom-args/captures/capt3.scala @@ -0,0 +1,26 @@ +class C +type Cap = C @retains(*) + +def test1() = + val x: Cap = C() + val y = () => { x; () } + val z = y + z: (() => Unit) // error + +def test2() = + val x: Cap = C() + def y = () => { x; () } + def z = y + z: (() => Unit) // error + +def test3() = + val x: Cap = C() + def y = () => { x; () } + val z = y + z: (() => Unit) // error + +def test4() = + val x: Cap = C() + val y = () => { x; () } + def z = y + z: (() => Unit) // error diff --git a/tests/neg-custom-args/captures/cc1.scala b/tests/neg-custom-args/captures/cc1.scala new file mode 100644 index 000000000000..ebd983c58fe9 --- /dev/null +++ b/tests/neg-custom-args/captures/cc1.scala @@ -0,0 +1,4 @@ +object Test: + + def f[A <: Matchable @retains(*)](x: A): Matchable = x // error + diff --git a/tests/neg-custom-args/captures/classes.scala b/tests/neg-custom-args/captures/classes.scala new file mode 100644 index 000000000000..b87d21913d4e --- /dev/null +++ b/tests/neg-custom-args/captures/classes.scala @@ -0,0 +1,12 @@ +class B +type Cap = {*} B +class C0(n: Cap) // error: class parameter must be a `val`. + +class C(val n: Cap): + def foo(): {n} B = n + +def test(x: Cap, y: Cap) = + val c0 = C(x) + val c1: C = c0 // error + val c2 = if ??? then C(x) else /*identity*/(C(y)) // TODO: uncomment + val c3: {x} C { val n: {x, y} B } = c2 // error diff --git a/tests/neg-custom-args/captures/io.scala b/tests/neg-custom-args/captures/io.scala new file mode 100644 index 000000000000..17c22a2111e4 --- /dev/null +++ b/tests/neg-custom-args/captures/io.scala @@ -0,0 +1,22 @@ +sealed trait IO: + def puts(msg: Any): Unit = println(msg) + +def test1 = + val IO : IO @retains(*) = new IO {} + def foo = {IO; IO.puts("hello") } + val x : () => Unit = () => foo // error: Found: (() => Unit) retains IO; Required: () => Unit + +def test2 = + val IO : IO @retains(*) = new IO {} + def puts(msg: Any, io: IO @retains(*)) = println(msg) + def foo() = puts("hello", IO) + val x : () => Unit = () => foo() // error: Found: (() => Unit) retains IO; Required: () => Unit + +type Capability[T] = T @retains(*) + +def test3 = + val IO : Capability[IO] = new IO {} + def puts(msg: Any, io: Capability[IO]) = println(msg) + def foo() = puts("hello", IO) + val x : () => Unit = () => foo() // error: Found: (() => Unit) retains IO; Required: () => Unit + diff --git a/tests/neg-custom-args/captures/lazylist.check b/tests/neg-custom-args/captures/lazylist.check new file mode 100644 index 000000000000..3a80de9bdf16 --- /dev/null +++ b/tests/neg-custom-args/captures/lazylist.check @@ -0,0 +1,42 @@ +-- [E163] Declaration Error: tests/neg-custom-args/captures/lazylist.scala:22:6 ---------------------------------------- +22 | def tail: {*} LazyList[Nothing] = ??? // error overriding + | ^ + | error overriding method tail in class LazyList of type => lazylists.LazyList[Nothing]; + | method tail of type => {*} lazylists.LazyList[Nothing] has incompatible type + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylist.scala:35:29 ------------------------------------- +35 | val ref1c: LazyList[Int] = ref1 // error + | ^^^^ + | Found: (ref1 : {cap1} lazylists.LazyCons[Int]{xs: {cap1} () => {*} lazylists.LazyList[Int]}) + | Required: lazylists.LazyList[Int] + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylist.scala:37:36 ------------------------------------- +37 | val ref2c: {ref1} LazyList[Int] = ref2 // error + | ^^^^ + | Found: (ref2 : {cap2, ref1} lazylists.LazyList[Int]) + | Required: {ref1} lazylists.LazyList[Int] + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylist.scala:39:36 ------------------------------------- +39 | val ref3c: {cap2} LazyList[Int] = ref3 // error + | ^^^^ + | Found: (ref3 : {cap2, ref1} lazylists.LazyList[Int]) + | Required: {cap2} lazylists.LazyList[Int] + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylist.scala:41:48 ------------------------------------- +41 | val ref4c: {cap1, ref3, cap3} LazyList[Int] = ref4 // error + | ^^^^ + | Found: (ref4 : {cap3, cap2, ref1, cap1} lazylists.LazyList[Int]) + | Required: {cap1, ref3, cap3} lazylists.LazyList[Int] + +longer explanation available when compiling with `-explain` +-- Error: tests/neg-custom-args/captures/lazylist.scala:17:6 ----------------------------------------------------------- +17 | def tail = xs() // error: cannot have an inferred type + | ^^^^^^^^^^^^^^^ + | Non-local method tail cannot have an inferred result type + | {*} lazylists.LazyList[T] + | with non-empty capture set {*}. + | The type needs to be declared explicitly. diff --git a/tests/neg-custom-args/captures/lazylist.scala b/tests/neg-custom-args/captures/lazylist.scala new file mode 100644 index 000000000000..f7be43e8dc27 --- /dev/null +++ b/tests/neg-custom-args/captures/lazylist.scala @@ -0,0 +1,41 @@ +package lazylists + +abstract class LazyList[+T]: + this: ({*} LazyList[T]) => + + def isEmpty: Boolean + def head: T + def tail: LazyList[T] + + def map[U](f: {*} T => U): {f, this} LazyList[U] = + if isEmpty then LazyNil + else LazyCons(f(head), () => tail.map(f)) + +class LazyCons[+T](val x: T, val xs: {*} () => {*} LazyList[T]) extends LazyList[T]: + def isEmpty = false + def head = x + def tail = xs() // error: cannot have an inferred type + +object LazyNil extends LazyList[Nothing]: + def isEmpty = true + def head = ??? + def tail: {*} LazyList[Nothing] = ??? // error overriding + +def map[A, B](xs: {*} LazyList[A], f: {*} A => B): {f, xs} LazyList[B] = + xs.map(f) + +class CC +type Cap = {*} CC + +def test(cap1: Cap, cap2: Cap, cap3: Cap) = + def f[T](x: LazyList[T]): LazyList[T] = if cap1 == cap1 then x else LazyNil + def g(x: Int) = if cap2 == cap2 then x else 0 + def h(x: Int) = if cap3 == cap3 then x else 0 + val ref1 = LazyCons(1, () => f(LazyNil)) + val ref1c: LazyList[Int] = ref1 // error + val ref2 = map(ref1, g) + val ref2c: {ref1} LazyList[Int] = ref2 // error + val ref3 = ref1.map(g) + val ref3c: {cap2} LazyList[Int] = ref3 // error + val ref4 = (if cap1 == cap2 then ref1 else ref2).map(h) + val ref4c: {cap1, ref3, cap3} LazyList[Int] = ref4 // error diff --git a/tests/neg-custom-args/captures/lazyref.check b/tests/neg-custom-args/captures/lazyref.check new file mode 100644 index 000000000000..2affed020dec --- /dev/null +++ b/tests/neg-custom-args/captures/lazyref.check @@ -0,0 +1,28 @@ +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazyref.scala:19:28 -------------------------------------- +19 | val ref1c: LazyRef[Int] = ref1 // error + | ^^^^ + | Found: (ref1 : {cap1} LazyRef[Int]{elem: {cap1} () => Int}) + | Required: LazyRef[Int] + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazyref.scala:21:35 -------------------------------------- +21 | val ref2c: {cap2} LazyRef[Int] = ref2 // error + | ^^^^ + | Found: (ref2 : {cap2, ref1} LazyRef[Int]{elem: {*} () => Int}) + | Required: {cap2} LazyRef[Int] + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazyref.scala:23:35 -------------------------------------- +23 | val ref3c: {ref1} LazyRef[Int] = ref3 // error + | ^^^^ + | Found: (ref3 : {cap2, ref1} LazyRef[Int]{elem: {*} () => Int}) + | Required: {ref1} LazyRef[Int] + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazyref.scala:25:35 -------------------------------------- +25 | val ref4c: {cap1} LazyRef[Int] = ref4 // error + | ^^^^ + | Found: (ref4 : {cap2, cap1} LazyRef[Int]{elem: {*} () => Int}) + | Required: {cap1} LazyRef[Int] + +longer explanation available when compiling with `-explain` diff --git a/tests/neg-custom-args/captures/lazyref.scala b/tests/neg-custom-args/captures/lazyref.scala new file mode 100644 index 000000000000..1002f9685675 --- /dev/null +++ b/tests/neg-custom-args/captures/lazyref.scala @@ -0,0 +1,25 @@ +class CC +type Cap = {*} CC + +class LazyRef[T](val elem: {*} () => T): + val get = elem + def map[U](f: {*} T => U): {f, this} LazyRef[U] = + new LazyRef(() => f(elem())) + +def map[A, B](ref: {*} LazyRef[A], f: {*} A => B): {f, ref} LazyRef[B] = + new LazyRef(() => f(ref.elem())) + +def mapc[A, B]: (ref: {*} LazyRef[A], f: {*} A => B) => {f, ref} LazyRef[B] = + (ref1, f1) => map[A, B](ref1, f1) + +def test(cap1: Cap, cap2: Cap) = + def f(x: Int) = if cap1 == cap1 then x else 0 + def g(x: Int) = if cap2 == cap2 then x else 0 + val ref1 = LazyRef(() => f(0)) + val ref1c: LazyRef[Int] = ref1 // error + val ref2 = map(ref1, g) + val ref2c: {cap2} LazyRef[Int] = ref2 // error + val ref3 = ref1.map(g) + val ref3c: {ref1} LazyRef[Int] = ref3 // error + val ref4 = (if cap1 == cap2 then ref1 else ref2).map(g) + val ref4c: {cap1} LazyRef[Int] = ref4 // error diff --git a/tests/neg-custom-args/captures/try.check b/tests/neg-custom-args/captures/try.check new file mode 100644 index 000000000000..bd95835c6525 --- /dev/null +++ b/tests/neg-custom-args/captures/try.check @@ -0,0 +1,25 @@ +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/try.scala:28:43 ------------------------------------------ +28 | val b = handle[Exception, () => Nothing] { // error + | ^ + | Found: ? (x: CanThrow[Exception]) => {x} () => ? Nothing + | Required: CanThrow[Exception] => () => Nothing +29 | (x: CanThrow[Exception]) => () => raise(new Exception)(using x) +30 | } { + +longer explanation available when compiling with `-explain` +-- Error: tests/neg-custom-args/captures/try.scala:22:28 --------------------------------------------------------------- +22 | val a = handle[Exception, CanThrow[Exception]] { // error + | ^^^^^^^^^^^^^^^^^^^ + | type argument is not allowed to capture the universal capability (* : Any) +-- Error: tests/neg-custom-args/captures/try.scala:34:11 --------------------------------------------------------------- +34 | val xx = handle { // error + | ^^^^^^ + | inferred type argument {*} () => Int is not allowed to capture the universal capability (* : Any) + | + | The inferred arguments are: [? Exception, {*} () => Int] +-- Error: tests/neg-custom-args/captures/try.scala:46:13 --------------------------------------------------------------- +46 |val global = handle { // error + | ^^^^^^ + | inferred type argument {*} () => Int is not allowed to capture the universal capability (* : Any) + | + | The inferred arguments are: [? Exception, {*} () => Int] diff --git a/tests/neg-custom-args/captures/try.scala b/tests/neg-custom-args/captures/try.scala new file mode 100644 index 000000000000..804a16192be0 --- /dev/null +++ b/tests/neg-custom-args/captures/try.scala @@ -0,0 +1,53 @@ +import language.experimental.erasedDefinitions + +class CT[E <: Exception] +type CanThrow[E <: Exception] = CT[E] @retains(*) +type Top = Any @retains(*) + +infix type throws[R, E <: Exception] = (erased CanThrow[E]) ?=> R + +class Fail extends Exception + +def raise[E <: Exception](e: E): Nothing throws E = throw e + +def foo(x: Boolean): Int throws Fail = + if x then 1 else raise(Fail()) + +def handle[E <: Exception, R <: Top](op: CanThrow[E] => R)(handler: E => R): R = + val x: CanThrow[E] = ??? + try op(x) + catch case ex: E => handler(ex) + +def test = + val a = handle[Exception, CanThrow[Exception]] { // error + (x: CanThrow[Exception]) => x + }{ + (ex: Exception) => ??? + } + + val b = handle[Exception, () => Nothing] { // error + (x: CanThrow[Exception]) => () => raise(new Exception)(using x) + } { + (ex: Exception) => ??? + } + + val xx = handle { // error + (x: CanThrow[Exception]) => + () => + raise(new Exception)(using x) + 22 + } { + (ex: Exception) => () => 22 + } + val yy = xx :: Nil + yy // OK + + +val global = handle { // error + (x: CanThrow[Exception]) => + () => + raise(new Exception)(using x) + 22 +} { + (ex: Exception) => () => 22 +} \ No newline at end of file diff --git a/tests/neg-custom-args/captures/try3.scala b/tests/neg-custom-args/captures/try3.scala new file mode 100644 index 000000000000..4fbb980b9e03 --- /dev/null +++ b/tests/neg-custom-args/captures/try3.scala @@ -0,0 +1,27 @@ +import java.io.IOException + +class CT[E] +type CanThrow[E] = {*} CT[E] +type Top = {*} Any + +def handle[E <: Exception, T <: Top](op: CanThrow[E] ?=> T)(handler: E => T): T = + val x: CanThrow[E] = ??? + try op(using x) + catch case ex: E => handler(ex) + +def raise[E <: Exception](ex: E)(using CanThrow[E]): Nothing = + throw ex + +@main def Test: Int = + def f(a: Boolean) = + handle { // error + if !a then raise(IOException()) + (b: Boolean) => + if !b then raise(IOException()) + 0 + } { + ex => (b: Boolean) => -1 + } + val g = f(true) + g(false) // would raise an uncaught exception + f(true)(false) // would raise an uncaught exception diff --git a/tests/neg/multiLineOps.scala b/tests/neg/multiLineOps.scala index 8499cc9fe710..08a0a3925fd1 100644 --- a/tests/neg/multiLineOps.scala +++ b/tests/neg/multiLineOps.scala @@ -5,7 +5,7 @@ val x = 1 val b1 = { 22 * 22 // ok - */*one more*/22 // error: end of statement expected // error: not found: * + */*one more*/22 // error: end of statement expected } val b2: Boolean = { diff --git a/tests/neg/polymorphic-functions1.check b/tests/neg/polymorphic-functions1.check new file mode 100644 index 000000000000..86492e96dab5 --- /dev/null +++ b/tests/neg/polymorphic-functions1.check @@ -0,0 +1,7 @@ +-- [E007] Type Mismatch Error: tests/neg/polymorphic-functions1.scala:1:53 --------------------------------------------- +1 |val f: [T] => (x: T) => x.type = [T] => (x: Int) => x // error + | ^ + | Found: [T] => (x: Int) => Int + | Required: [T] => (x: T) => x.type + +longer explanation available when compiling with `-explain` diff --git a/tests/neg/polymorphic-functions1.scala b/tests/neg/polymorphic-functions1.scala new file mode 100644 index 000000000000..de887f3b8c50 --- /dev/null +++ b/tests/neg/polymorphic-functions1.scala @@ -0,0 +1 @@ +val f: [T] => (x: T) => x.type = [T] => (x: Int) => x // error diff --git a/tests/pos-custom-args/captures/bounded.scala b/tests/pos-custom-args/captures/bounded.scala new file mode 100644 index 000000000000..fad0b50c2137 --- /dev/null +++ b/tests/pos-custom-args/captures/bounded.scala @@ -0,0 +1,14 @@ +class CC +type Cap = {*} CC + +def test(c: Cap) = + class B[X <: {c} Object](x: X): + def elem = x + def lateElem = () => x + + def f(x: Int): Int = if c == c then x else 0 + val b = new B(f) + val r1 = b.elem + val r1c: {c} Int => Int = r1 + val r2 = b.lateElem + val r2c: {c} () => {c} Int => Int = r2 \ No newline at end of file diff --git a/tests/pos-custom-args/captures/boxmap-paper.scala b/tests/pos-custom-args/captures/boxmap-paper.scala new file mode 100644 index 000000000000..ed8c648526d1 --- /dev/null +++ b/tests/pos-custom-args/captures/boxmap-paper.scala @@ -0,0 +1,38 @@ +infix type ==> [A, B] = {*} (A => B) + +type Cell[+T] = [K] => (T ==> K) => K + +def cell[T](x: T): Cell[T] = + [K] => (k: T ==> K) => k(x) + +def get[T](c: Cell[T]): T = c[T](identity) + +def map[A, B](c: Cell[A])(f: A ==> B): Cell[B] + = c[Cell[B]]((x: A) => cell(f(x))) + +def pureMap[A, B](c: Cell[A])(f: A => B): Cell[B] + = c[Cell[B]]((x: A) => cell(f(x))) + +def lazyMap[A, B](c: Cell[A])(f: A ==> B): {f} () => Cell[B] + = () => c[Cell[B]]((x: A) => cell(f(x))) + +trait IO: + def print(s: String): Unit + +def test(io: {*} IO) = + + val loggedOne: {io} () => Int = () => { io.print("1"); 1 } + + val c: Cell[{io} () => Int] + = cell[{io} () => Int](loggedOne) + + val g = (f: {io} () => Int) => + val x = f(); io.print(" + ") + val y = f(); io.print(s" = ${x + y}") + + val r = lazyMap[{io} () => Int, Unit](c)(f => g(f)) + val r2 = lazyMap[{io} () => Int, Unit](c)(g) + val r3 = lazyMap(c)(g) + val _ = r() + val _ = r2() + val _ = r3() diff --git a/tests/pos-custom-args/captures/boxmap.scala b/tests/pos-custom-args/captures/boxmap.scala new file mode 100644 index 000000000000..a0dcade2b179 --- /dev/null +++ b/tests/pos-custom-args/captures/boxmap.scala @@ -0,0 +1,20 @@ +type Top = Any @retains(*) + +infix type ==> [A, B] = (A => B) @retains(*) + +type Box[+T <: Top] = ([K <: Top] => (T ==> K) => K) + +def box[T <: Top](x: T): Box[T] = + [K <: Top] => (k: T ==> K) => k(x) + +def map[A <: Top, B <: Top](b: Box[A])(f: A ==> B): Box[B] = + b[Box[B]]((x: A) => box(f(x))) + +def lazymap[A <: Top, B <: Top](b: Box[A])(f: A ==> B): (() => Box[B]) @retains(f) = + () => b[Box[B]]((x: A) => box(f(x))) + +def test[A <: Top, B <: Top] = + def lazymap[A <: Top, B <: Top](b: Box[A])(f: A ==> B) = + () => b[Box[B]]((x: A) => box(f(x))) + val x: (b: Box[A]) => (f: A ==> B) => (() => Box[B]) @retains(f) = lazymap[A, B] + () diff --git a/tests/pos-custom-args/captures/byname.scala b/tests/pos-custom-args/captures/byname.scala new file mode 100644 index 000000000000..917154079b36 --- /dev/null +++ b/tests/pos-custom-args/captures/byname.scala @@ -0,0 +1,10 @@ +class CC +type Cap = {*} CC + +class I + +def test(cap1: Cap, cap2: Cap): {cap1} I = + def f() = if cap1 == cap1 then I() else I() + def h(x: => {cap1} I) = x + h(f()) + diff --git a/tests/pos-custom-args/captures/capt-depfun.scala b/tests/pos-custom-args/captures/capt-depfun.scala new file mode 100644 index 000000000000..6b99eff32692 --- /dev/null +++ b/tests/pos-custom-args/captures/capt-depfun.scala @@ -0,0 +1,18 @@ +class C +type Cap = C @retains(*) + +type T = (x: Cap) => String @retains(x) + +val aa: ((x: Cap) => String @retains(x)) = (x: Cap) => "" + +def f(y: Cap, z: Cap): String @retains(*) = + val a: ((x: Cap) => String @retains(x)) = (x: Cap) => "" + val b = a(y) + val c: String @retains(y) = b + def g(): C @retains(y, z) = ??? + val d = a(g()) + + val ac: ((x: Cap) => String @retains(x) => String @retains(x)) = ??? + val bc: (({y} String) => {y} String) = ac(y) + val dc: (String => {y, z} String) = ac(g()) + c diff --git a/tests/pos-custom-args/captures/capt-depfun2.scala b/tests/pos-custom-args/captures/capt-depfun2.scala new file mode 100644 index 000000000000..17f98b4a1554 --- /dev/null +++ b/tests/pos-custom-args/captures/capt-depfun2.scala @@ -0,0 +1,8 @@ +class C +type Cap = C @retains(*) + +def f(y: Cap, z: Cap) = + def g(): C @retains(y, z) = ??? + val ac: ((x: Cap) => Array[String @retains(x)]) = ??? + val dc: Array[? >: String <: {y, z} String] = ac(g()) // needs to be inferred + val ec = ac(y) diff --git a/tests/pos-custom-args/captures/capt-test.scala b/tests/pos-custom-args/captures/capt-test.scala new file mode 100644 index 000000000000..f40bd2ff1746 --- /dev/null +++ b/tests/pos-custom-args/captures/capt-test.scala @@ -0,0 +1,35 @@ +abstract class LIST[+T]: + def isEmpty: Boolean + def head: T + def tail: LIST[T] + def map[U](f: {*} T => U): LIST[U] = + if isEmpty then NIL + else CONS(f(head), tail.map(f)) + +class CONS[+T](x: T, xs: LIST[T]) extends LIST[T]: + def isEmpty = false + def head = x + def tail = xs +object NIL extends LIST[Nothing]: + def isEmpty = true + def head = ??? + def tail = ??? + +def map[A, B](f: {*} A => B)(xs: LIST[A]): LIST[B] = + xs.map(f) + +class C +type Cap = {*} C + +def test(c: Cap, d: Cap) = + def f(x: Cap): Unit = if c == x then () + def g(x: Cap): Unit = if d == x then () + val y = f + val ys = CONS(y, NIL) + val zs = + val z = g + CONS(z, ys) + val zsc: LIST[{d, y} Cap => Unit] = zs + + val a4 = zs.map(identity) + val a4c: LIST[{d, y} Cap => Unit] = a4 diff --git a/tests/pos-custom-args/captures/capt0.scala b/tests/pos-custom-args/captures/capt0.scala new file mode 100644 index 000000000000..c8ff8a102856 --- /dev/null +++ b/tests/pos-custom-args/captures/capt0.scala @@ -0,0 +1,7 @@ +object Test: + + def test() = + val x: {*} Any = "abc" + val y: Object @scala.retains(x) = ??? + val z: Object @scala.retains(x, *) = y: Object @scala.retains(x) + diff --git a/tests/pos-custom-args/captures/capt1.scala b/tests/pos-custom-args/captures/capt1.scala new file mode 100644 index 000000000000..14c0855544d4 --- /dev/null +++ b/tests/pos-custom-args/captures/capt1.scala @@ -0,0 +1,27 @@ +class C +type Cap = {*} C +def f1(c: Cap): {c} () => c.type = () => c // ok + +def f2: Int = + val g: {*} Boolean => Int = ??? + val x = g(true) + x + +def f3: Int = + def g: {*} Boolean => Int = ??? + def h = g + val x = g.apply(true) + x + +def foo() = + val x: {*} C = ??? + val y: {x} C = x + val x2: {x} () => C = ??? + val y2: {x} () => {x} C = x2 + + val z1: {*} () => Cap = f1(x) + def h[X](a: X)(b: X) = a + + val z2 = + if x == null then () => x else () => C() + x \ No newline at end of file diff --git a/tests/pos-custom-args/captures/capt2.scala b/tests/pos-custom-args/captures/capt2.scala new file mode 100644 index 000000000000..e3d4cd67b30c --- /dev/null +++ b/tests/pos-custom-args/captures/capt2.scala @@ -0,0 +1,20 @@ +import scala.retains +class C +type Cap = C @retains(*) + +def test1() = + val y: {*} String = "" + def x: Object @retains(y) = y + +def test2() = + val x: Cap = C() + val y = () => { x; () } + def z: (() => Unit) @retains(x) = y + z: (() => Unit) @retains(x) + def z2: (() => Unit) @retains(y) = y + z2: (() => Unit) @retains(y) + val p: {*} () => String = () => "abc" + val q: {p} C = ??? + p: ({p} () => String) + + diff --git a/tests/pos-custom-args/captures/cc-expand.scala b/tests/pos-custom-args/captures/cc-expand.scala new file mode 100644 index 000000000000..eedc95554b17 --- /dev/null +++ b/tests/pos-custom-args/captures/cc-expand.scala @@ -0,0 +1,21 @@ +object Test: + + class A + class B + class C + class CTC + type CT = CTC @retains(*) + + def test(ct: CT, dt: CT) = + + def x0: A => {ct} B = ??? + + def x1: A => B @retains(ct) = ??? + def x2: A => B => C @retains(ct) = ??? + def x3: A => () => B => C @retains(ct) = ??? + + def x4: (x: A @retains(ct)) => B => C = ??? + + def x5: A => (x: B @retains(ct)) => () => C @retains(dt) = ??? + def x6: A => (x: B @retains(ct)) => (() => C @retains(dt)) @retains(x, dt) = ??? + def x7: A => (x: B @retains(ct)) => (() => C @retains(dt)) @retains(x) = ??? \ No newline at end of file diff --git a/tests/pos-custom-args/captures/classes.scala b/tests/pos-custom-args/captures/classes.scala new file mode 100644 index 000000000000..f3d6e44b27ca --- /dev/null +++ b/tests/pos-custom-args/captures/classes.scala @@ -0,0 +1,34 @@ +class B +type Cap = {*} B +class C(val n: Cap): + this: ({n} C) => + def foo(): {n} B = n + + +def test(x: Cap, y: Cap, z: Cap) = + val c0 = C(x) + val c1: {x} C {val n: {x} B} = c0 + val d = c1.foo() + d: ({x} B) + + val c2 = if ??? then C(x) else C(y) + val c2a = identity(c2) + val c3: {x, y} C { val n: {x, y} B } = c2 + val d1 = c3.foo() + d1: B @retains(x, y) + + class Local: + + def this(a: Cap) = + this() + if a == z then println("?") + + val f = y + def foo = x + end Local + + val l = Local() + val l1: {x, y} Local = l + val l2 = Local(x) + val l3: {x, y, z} Local = l2 + diff --git a/tests/pos-custom-args/captures/iterators.scala b/tests/pos-custom-args/captures/iterators.scala new file mode 100644 index 000000000000..dd1067bcdc72 --- /dev/null +++ b/tests/pos-custom-args/captures/iterators.scala @@ -0,0 +1,23 @@ +package cctest + +abstract class Iterator[T]: + thisIterator => + + def hasNext: Boolean + def next: T + def map(f: {*} T => T): {f} Iterator[T] = new Iterator: + def hasNext = thisIterator.hasNext + def next = f(thisIterator.next) +end Iterator + +class C +type Cap = {*} C + +def test(c: Cap, d: Cap, e: Cap) = + val it = new Iterator[Int]: + private var ctr = 0 + def hasNext = ctr < 10 + def next = { ctr += 1; ctr } + + def f(x: Int): Int = if c == d then x else 10 + val it2 = it.map(f) diff --git a/tests/pos-custom-args/captures/lazyref.scala b/tests/pos-custom-args/captures/lazyref.scala new file mode 100644 index 000000000000..39748b00506b --- /dev/null +++ b/tests/pos-custom-args/captures/lazyref.scala @@ -0,0 +1,25 @@ +class CC +type Cap = {*} CC + +class LazyRef[T](val elem: {*} () => T): + val get = elem + def map[U](f: {*} T => U): {f, this} LazyRef[U] = + new LazyRef(() => f(elem())) + +def map[A, B](ref: {*} LazyRef[A], f: {*} A => B): {f, ref} LazyRef[B] = + new LazyRef(() => f(ref.elem())) + +def mapc[A, B]: (ref: {*} LazyRef[A], f: {*} A => B) => {f, ref} LazyRef[B] = + (ref1, f1) => map[A, B](ref1, f1) + +def test(cap1: Cap, cap2: Cap) = + def f(x: Int) = if cap1 == cap1 then x else 0 + def g(x: Int) = if cap2 == cap2 then x else 0 + val ref1 = LazyRef(() => f(0)) + val ref1c: {cap1} LazyRef[Int] = ref1 + val ref2 = map(ref1, g) + val ref2c: {cap2, ref1} LazyRef[Int] = ref2 + val ref3 = ref1.map(g) + val ref3c: {cap2, ref1} LazyRef[Int] = ref3 + val ref4 = (if cap1 == cap2 then ref1 else ref2).map(g) + val ref4c: {cap1, cap2} LazyRef[Int] = ref4 diff --git a/tests/pos-custom-args/captures/list-encoding.scala b/tests/pos-custom-args/captures/list-encoding.scala new file mode 100644 index 000000000000..74bc8bd2b099 --- /dev/null +++ b/tests/pos-custom-args/captures/list-encoding.scala @@ -0,0 +1,23 @@ +package listEncoding + +class Cap + +type Op[T, C] = + {*} (v: T) => {*} (s: C) => C + +type List[T] = + [C] => (op: Op[T, C]) => {op} (s: C) => C + +def nil[T]: List[T] = + [C] => (op: Op[T, C]) => (s: C) => s + +def cons[T](hd: T, tl: List[T]): List[T] = + [C] => (op: Op[T, C]) => (s: C) => op(hd)(tl(op)(s)) + +def foo(c: {*} Cap) = + def f(x: String @retains(c), y: String @retains(c)) = + cons(x, cons(y, nil)) + def g(x: String @retains(c), y: Any) = + cons(x, cons(y, nil)) + def h(x: String, y: Any @retains(c)) = + cons(x, cons(y, nil)) diff --git a/tests/pos-custom-args/captures/lists.scala b/tests/pos-custom-args/captures/lists.scala new file mode 100644 index 000000000000..139f885ec87a --- /dev/null +++ b/tests/pos-custom-args/captures/lists.scala @@ -0,0 +1,91 @@ +abstract class LIST[+T]: + def isEmpty: Boolean + def head: T + def tail: LIST[T] + def map[U](f: {*} T => U): LIST[U] = + if isEmpty then NIL + else CONS(f(head), tail.map(f)) + +class CONS[+T](x: T, xs: LIST[T]) extends LIST[T]: + def isEmpty = false + def head = x + def tail = xs +object NIL extends LIST[Nothing]: + def isEmpty = true + def head = ??? + def tail = ??? + +def map[A, B](f: {*} A => B)(xs: LIST[A]): LIST[B] = + xs.map(f) + +class C +type Cap = {*} C + +def test(c: Cap, d: Cap, e: Cap) = + def f(x: Cap): Unit = if c == x then () + def g(x: Cap): Unit = if d == x then () + val y = f + val ys = CONS(y, NIL) + val zs = + val z = g + CONS(z, ys) + val zsc: LIST[{d, y} Cap => Unit] = zs + val z1 = zs.head + val z1c: {y, d} Cap => Unit = z1 + val ys1 = zs.tail + val y1 = ys1.head + + + def m1[A, B] = + (f: {*} A => B) => (xs: LIST[A]) => xs.map(f) + + def m1c: (f: {*} String => Int) => {f} LIST[String] => LIST[Int] = m1[String, Int] + + def m2 = [A, B] => + (f: {*} A => B) => (xs: LIST[A]) => xs.map(f) + + def m2c: [A, B] => (f: {*} A => B) => {f} LIST[A] => LIST[B] = m2 + + def eff[A](x: A) = if x == e then x else x + + val eff2 = [A] => (x: A) => if x == e then x else x + + val a0 = identity[{d, y} Cap => Unit] + val a0c: ({d, y} Cap => Unit) => {d, y} Cap => Unit = a0 + val a1 = zs.map[{d, y} Cap => Unit](a0) + val a1c: LIST[{d, y} Cap => Unit] = a1 + val a2 = zs.map[{d, y} Cap => Unit](identity[{d, y} Cap => Unit]) + val a2c: LIST[{d, y} Cap => Unit] = a2 + val a3 = zs.map(identity[{d, y} Cap => Unit]) + val a3c: LIST[{d, y} Cap => Unit] = a3 + val a4 = zs.map(identity) + val a4c: LIST[{d, c} Cap => Unit] = a4 + val a5 = map[{d, y} Cap => Unit, {d, y} Cap => Unit](identity)(zs) + val a5c: LIST[{d, c} Cap => Unit] = a5 + val a6 = m1[{d, y} Cap => Unit, {d, y} Cap => Unit](identity)(zs) + val a6c: LIST[{d, c} Cap => Unit] = a6 + + val b0 = eff[{d, y} Cap => Unit] + val b0c: {e} ({d, y} Cap => Unit) => {d, y} Cap => Unit = b0 + val b1 = zs.map[{d, y} Cap => Unit](a0) + val b1c: {e} LIST[{d, y} Cap => Unit] = b1 + val b2 = zs.map[{d, y} Cap => Unit](eff[{d, y} Cap => Unit]) + val b2c: {e} LIST[{d, y} Cap => Unit] = b2 + val b3 = zs.map(eff[{d, y} Cap => Unit]) + val b3c: {e} LIST[{d, y} Cap => Unit] = b3 + val b4 = zs.map(eff) + val b4c: {e} LIST[{d, c} Cap => Unit] = b4 + val b5 = map[{d, y} Cap => Unit, {d, y} Cap => Unit](eff)(zs) + val b5c: {e} LIST[{d, c} Cap => Unit] = b5 + val b6 = m1[{d, y} Cap => Unit, {d, y} Cap => Unit](eff)(zs) + val b6c: {e} LIST[{d, c} Cap => Unit] = b6 + + val c0 = eff2[{d, y} Cap => Unit] + val c0c: {e} ({d, y} Cap => Unit) => {d, y} Cap => Unit = c0 + val c1 = zs.map[{d, y} Cap => Unit](a0) + val c1c: {e} LIST[{d, y} Cap => Unit] = c1 + val c2 = zs.map[{d, y} Cap => Unit](eff2[{d, y} Cap => Unit]) + val c2c: {e} LIST[{d, y} Cap => Unit] = c2 + val c3 = zs.map(eff2[{d, y} Cap => Unit]) + val c3c: {e} LIST[{d, y} Cap => Unit] = c3 + diff --git a/tests/pos-custom-args/captures/pairs.scala b/tests/pos-custom-args/captures/pairs.scala new file mode 100644 index 000000000000..4f23a086a075 --- /dev/null +++ b/tests/pos-custom-args/captures/pairs.scala @@ -0,0 +1,33 @@ + +class C +type Cap = {*} C + +object Generic: + + class Pair[+A, +B](x: A, y: B): + def fst: A = x + def snd: B = y + + def test(c: Cap, d: Cap) = + def f(x: Cap): Unit = if c == x then () + def g(x: Cap): Unit = if d == x then () + val p = Pair(f, g) + val x1 = p.fst + val x1c: {c} Cap => Unit = x1 + val y1 = p.snd + val y1c: {d} Cap => Unit = y1 + +object Monomorphic: + + class Pair(val x: {*} Cap => Unit, val y: {*} Cap => Unit): + def fst = x + def snd = y + + def test(c: Cap, d: Cap) = + def f(x: Cap): Unit = if c == x then () + def g(x: Cap): Unit = if d == x then () + val p = Pair(f, g) + val x1 = p.fst + val x1c: {c} Cap => Unit = x1 + val y1 = p.snd + val y1c: {d} Cap => Unit = y1 diff --git a/tests/pos-custom-args/captures/try.scala b/tests/pos-custom-args/captures/try.scala new file mode 100644 index 000000000000..a50eeabfb3a3 --- /dev/null +++ b/tests/pos-custom-args/captures/try.scala @@ -0,0 +1,26 @@ +import language.experimental.erasedDefinitions + +class CT[E <: Exception] +type CanThrow[E <: Exception] = CT[E] @retains(*) + +infix type throws[R, E <: Exception] = (erased CanThrow[E]) ?=> R + +class Fail extends Exception + +def raise[E <: Exception](e: E): Nothing throws E = throw e + +def foo(x: Boolean): Int throws Fail = + if x then 1 else raise(Fail()) + +def handle[E <: Exception, R](op: (erased CanThrow[E]) => R)(handler: E => R): R = + erased val x: CanThrow[E] = ??? + try op(x) + catch case ex: E => handler(ex) + +val _ = handle { (erased x) => + if true then + raise(new Exception)(using x) + 22 + else + 11 + } \ No newline at end of file diff --git a/tests/pos-custom-args/captures/try3.scala b/tests/pos-custom-args/captures/try3.scala new file mode 100644 index 000000000000..074517d8a9e5 --- /dev/null +++ b/tests/pos-custom-args/captures/try3.scala @@ -0,0 +1,51 @@ +import language.experimental.erasedDefinitions +import annotation.ability +import java.io.IOException + +class CT[-E] // variance is needed for correct rechecking inference +type CanThrow[E] = {*} CT[E] + +def handle[E <: Exception, T](op: CanThrow[E] ?=> T)(handler: E => T): T = + val x: CanThrow[E] = ??? + try op(using x) + catch case ex: E => handler(ex) + +def raise[E <: Exception](ex: E)(using CanThrow[E]): Nothing = + throw ex + +def test1: Int = + def f(a: Boolean): Boolean => CanThrow[IOException] ?=> Int = + handle { + if !a then raise(IOException()) + (b: Boolean) => (_: CanThrow[IOException]) ?=> + if !b then raise(IOException()) + 0 + } { + ex => (b: Boolean) => (_: CanThrow[IOException]) ?=> -1 + } + handle { + val g = f(true) + g(false) // can raise an exception + f(true)(false) // can raise an exception + } { + ex => -1 + } +/* +def test2: Int = + def f(a: Boolean): Boolean => CanThrow[IOException] ?=> Int = + handle { // error + if !a then raise(IOException()) + (b: Boolean) => + if !b then raise(IOException()) + 0 + } { + ex => (b: Boolean) => -1 + } + handle { + val g = f(true) + g(false) // would raise an uncaught exception + f(true)(false) // would raise an uncaught exception + } { + ex => -1 + } +*/ \ No newline at end of file From 2941b6a176b72de41e3c5d724dd8deefa364af0e Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Sat, 2 Oct 2021 12:19:30 +0200 Subject: [PATCH 04/24] Include capture sets of methods in enclosing class --- .../dotty/tools/dotc/typer/CheckCaptures.scala | 7 ++++++- tests/neg-custom-args/captures/nestedclass.check | 7 +++++++ tests/neg-custom-args/captures/nestedclass.scala | 15 +++++++++++++++ 3 files changed, 28 insertions(+), 1 deletion(-) create mode 100644 tests/neg-custom-args/captures/nestedclass.check create mode 100644 tests/neg-custom-args/captures/nestedclass.scala diff --git a/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala index 1415016fea26..69f184bd79ef 100644 --- a/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala +++ b/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala @@ -263,7 +263,12 @@ class CheckCaptures extends Recheck: case ref: TermRef => ref.symbol.enclosure != ownEnclosure case _ => true } - checkSubset(targetSet, curEnv.captured, pos) + def includeIn(env: Env) = + capt.println(i"Include call capture $targetSet in ${env.owner}") + checkSubset(targetSet, env.captured, pos) + includeIn(curEnv) + if curEnv.owner.isTerm && curEnv.outer.owner.isClass then + includeIn(curEnv.outer) def includeBoxedCaptures(tp: Type, pos: SrcPos)(using Context): Unit = if curEnv.isOpen then diff --git a/tests/neg-custom-args/captures/nestedclass.check b/tests/neg-custom-args/captures/nestedclass.check new file mode 100644 index 000000000000..d3912d417a4c --- /dev/null +++ b/tests/neg-custom-args/captures/nestedclass.check @@ -0,0 +1,7 @@ +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/nestedclass.scala:15:15 ---------------------------------- +15 | val xsc: C = xs // error + | ^^ + | Found: (xs : {cap1} C) + | Required: C + +longer explanation available when compiling with `-explain` diff --git a/tests/neg-custom-args/captures/nestedclass.scala b/tests/neg-custom-args/captures/nestedclass.scala new file mode 100644 index 000000000000..38adf7998868 --- /dev/null +++ b/tests/neg-custom-args/captures/nestedclass.scala @@ -0,0 +1,15 @@ +class CC +type Cap = {*} CC + +abstract class C: + def head: String + +def test(cap1: Cap, cap2: Cap) = + def f(x: String): String = if cap1 == cap1 then "" else "a" + def g(x: String): String = if cap2 == cap2 then "" else "a" + + val xs = + class Cimpl extends C: + def head = f("") + new Cimpl + val xsc: C = xs // error From 20cf766a23bc3fb55a76a9b33530bf623ee28c21 Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Fri, 19 Nov 2021 13:39:21 +0100 Subject: [PATCH 05/24] Update iterators example --- tests/pos-custom-args/captures/iterators.scala | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/pos-custom-args/captures/iterators.scala b/tests/pos-custom-args/captures/iterators.scala index dd1067bcdc72..1ac1bd96f6d7 100644 --- a/tests/pos-custom-args/captures/iterators.scala +++ b/tests/pos-custom-args/captures/iterators.scala @@ -1,11 +1,11 @@ package cctest abstract class Iterator[T]: - thisIterator => + thisIterator: ({*} Iterator[T]) => def hasNext: Boolean def next: T - def map(f: {*} T => T): {f} Iterator[T] = new Iterator: + def map(f: {*} T => T): {f, this} Iterator[T] = new Iterator: def hasNext = thisIterator.hasNext def next = f(thisIterator.next) end Iterator @@ -13,6 +13,10 @@ end Iterator class C type Cap = {*} C +def map[T, U](it: {*} Iterator[T], f: {*} T => U): {it, f} Iterator[U] = new Iterator: + def hasNext = it.hasNext + def next = f(it.next) + def test(c: Cap, d: Cap, e: Cap) = val it = new Iterator[Int]: private var ctr = 0 @@ -21,3 +25,4 @@ def test(c: Cap, d: Cap, e: Cap) = def f(x: Int): Int = if c == d then x else 10 val it2 = it.map(f) + val it3 = map(it, f) \ No newline at end of file From f18d951a3b61d3b7cc5d6d71b6fac2ae642ed9cd Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Tue, 5 Oct 2021 10:43:47 +0200 Subject: [PATCH 06/24] Add capture checks for mutable variables - Mutable variables have boxed types, so that we do not need to track them when computing capture sets of classes. - Mutable variable types cannot capture `*` in order to prevent scope extrusion. --- .../dotty/tools/dotc/transform/Recheck.scala | 15 ++++++-- .../tools/dotc/typer/CheckCaptures.scala | 33 ++++++++++------- tests/neg-custom-args/captures/vars.check | 17 +++++++++ tests/neg-custom-args/captures/vars.scala | 37 +++++++++++++++++++ tests/pos-custom-args/captures/vars.scala | 18 +++++++++ 5 files changed, 102 insertions(+), 18 deletions(-) create mode 100644 tests/neg-custom-args/captures/vars.check create mode 100644 tests/neg-custom-args/captures/vars.scala create mode 100644 tests/pos-custom-args/captures/vars.scala diff --git a/compiler/src/dotty/tools/dotc/transform/Recheck.scala b/compiler/src/dotty/tools/dotc/transform/Recheck.scala index a61b736a9cc1..55a57ede2e0a 100644 --- a/compiler/src/dotty/tools/dotc/transform/Recheck.scala +++ b/compiler/src/dotty/tools/dotc/transform/Recheck.scala @@ -76,7 +76,7 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: val symd = sym.denot symd.validFor.firstPhaseId == thisPhase.id && (sym.originDenotation ne symd) - def transformType(tp: Type, inferred: Boolean)(using Context): Type = tp + def transformType(tp: Type, inferred: Boolean, boxed: Boolean = false)(using Context): Type = tp object transformTypes extends TreeTraverser: @@ -110,12 +110,19 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: mapOver(t) end SubstParams + private def transformTT(tree: TypeTree, boxed: Boolean)(using Context) = + transformType(tree.tpe, tree.isInstanceOf[InferredTypeTree], boxed).rememberFor(tree) + def traverse(tree: Tree)(using Context) = - traverseChildren(tree) tree match - + case tree @ ValDef(_, tpt: TypeTree, _) if tree.symbol.is(Mutable) => + transformTT(tpt, boxed = true) + traverse(tree.rhs) + case _ => + traverseChildren(tree) + tree match case tree: TypeTree => - transformType(tree.tpe, tree.isInstanceOf[InferredTypeTree]).rememberFor(tree) + transformTT(tree, boxed = false) case tree: ValOrDefDef => val sym = tree.symbol diff --git a/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala index 69f184bd79ef..c550b514e3ff 100644 --- a/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala +++ b/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala @@ -115,7 +115,7 @@ class CheckCaptures extends Recheck: class CaptureChecker(ictx: Context) extends Rechecker(ictx): import ast.tpd.* - override def transformType(tp: Type, inferred: Boolean)(using Context): Type = + override def transformType(tp: Type, inferred: Boolean, boxed: Boolean)(using Context): Type = def addInnerVars(tp: Type): Type = tp match case tp @ AppliedType(tycon, args) => @@ -191,15 +191,15 @@ class CheckCaptures extends Recheck: apply(parent) case _ => mapOver(t) - addVars(addFunctionRefinements(cleanup(tp))) + addVars(addFunctionRefinements(cleanup(tp)), boxed) .showing(i"reinfer $tp --> $result", capt) else - val addBoxes = new TypeTraverser: - def setBoxed(t: Type) = t match - case AnnotatedType(_, annot) if annot.symbol == defn.RetainsAnnot => - annot.tree.setBoxedCapturing() - case _ => + def setBoxed(t: Type) = t match + case AnnotatedType(_, annot) if annot.symbol == defn.RetainsAnnot => + annot.tree.setBoxedCapturing() + case _ => + val addBoxes = new TypeTraverser: def traverse(t: Type) = t match case AppliedType(tycon, args) if !defn.isNonRefinedFunction(t) => @@ -208,8 +208,8 @@ class CheckCaptures extends Recheck: setBoxed(lo); setBoxed(hi) case _ => traverseChildren(t) - end addBoxes + if boxed then setBoxed(tp) addBoxes.traverse(tp) tp end transformType @@ -417,12 +417,15 @@ class CheckCaptures extends Recheck: val what = if ref.isRootCapability then "universal" else "global" if isGlobal then val notAllowed = i" is not allowed to capture the $what capability $ref" - def msg = tree match - case tree: InferredTypeTree => - i"""inferred type argument ${knownType(tree)}$notAllowed - | - |The inferred arguments are: [${allArgs.map(knownType)}%, %]""" - case _ => s"type argument$notAllowed" + def msg = + if allArgs.isEmpty then + i"type of mutable variable ${knownType(tree)}$notAllowed" + else tree match + case tree: InferredTypeTree => + i"""inferred type argument ${knownType(tree)}$notAllowed + | + |The inferred arguments are: [${allArgs.map(knownType)}%, %]""" + case _ => s"type argument$notAllowed" report.error(msg, tree.srcPos) object PostRefinerCheck extends TreeTraverser: @@ -463,6 +466,8 @@ class CheckCaptures extends Recheck: |The type needs to be declared explicitly.""", t.srcPos) case _ => inferred.foreachPart(checkPure, StopAt.Static) + case t: ValDef if t.symbol.is(Mutable) => + checkNotGlobal(t.tpt) case _ => traverseChildren(tree) diff --git a/tests/neg-custom-args/captures/vars.check b/tests/neg-custom-args/captures/vars.check new file mode 100644 index 000000000000..ceba6f5fb422 --- /dev/null +++ b/tests/neg-custom-args/captures/vars.check @@ -0,0 +1,17 @@ +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/vars.scala:11:24 ----------------------------------------- +11 | val z2c: () => Unit = z2 // error + | ^^ + | Found: (z2 : {x, cap1} () => Unit) + | Required: () => Unit + +longer explanation available when compiling with `-explain` +-- Error: tests/neg-custom-args/captures/vars.scala:13:10 -------------------------------------------------------------- +13 | var a: {*} String => String = f // error + | ^^^^^^^^^^^^^^^^^^^ + | type of mutable variable box {*} String => String is not allowed to capture the universal capability (* : Any) +-- Error: tests/neg-custom-args/captures/vars.scala:27:2 --------------------------------------------------------------- +27 | local { cap3 => // error + | ^^^^^ + |inferred type argument {*} (x$0: ? String) => ? String is not allowed to capture the universal capability (* : Any) + | + |The inferred arguments are: [{*} (x$0: ? String) => ? String] diff --git a/tests/neg-custom-args/captures/vars.scala b/tests/neg-custom-args/captures/vars.scala new file mode 100644 index 000000000000..8c80eec75783 --- /dev/null +++ b/tests/neg-custom-args/captures/vars.scala @@ -0,0 +1,37 @@ +class CC +type Cap = {*} CC + +def test(cap1: Cap, cap2: Cap) = + def f(x: String): String = if cap1 == cap1 then "" else "a" + var x = f + val y = x + val z = () => if x("") == "" then "a" else "b" + val zc: {cap1} () => String = z + val z2 = () => { x = identity } + val z2c: () => Unit = z2 // error + + var a: {*} String => String = f // error + + def scope = + val cap3: Cap = CC() + def g(x: String): String = if cap3 == cap3 then "" else "a" + a = g + val gc = g + g + + val s = scope + val sc: {*} String => String = scope + + def local[T](op: Cap => T): T = op(CC()) + + local { cap3 => // error + def g(x: String): String = if cap3 == cap3 then "" else "a" + g + } + + class Ref: + var elem: {cap1} String => String = null + + val r = Ref() + r.elem = f + val fc = r.elem diff --git a/tests/pos-custom-args/captures/vars.scala b/tests/pos-custom-args/captures/vars.scala new file mode 100644 index 000000000000..aca56c55f386 --- /dev/null +++ b/tests/pos-custom-args/captures/vars.scala @@ -0,0 +1,18 @@ +class CC +type Cap = {*} CC + +def test(cap1: Cap, cap2: Cap) = + def f(x: String): String = if cap1 == cap1 then "" else "a" + var x = f + val y = x + val z = () => if x("") == "" then "a" else "b" + val zc: {cap1} () => String = z + val z2 = () => { x = identity } + val z2c: {cap1} () => Unit = z2 + + class Ref: + var elem: {cap1} String => String = null + + val r = Ref() + r.elem = f + val fc: {cap1} String => String = r.elem From a5d9d06b4a1e15802193b32245eca0eba7b722df Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Tue, 5 Oct 2021 12:23:41 +0200 Subject: [PATCH 07/24] Implement deep check for variables Scope extrusion can also happen for nested types, so we need to prevent {*} capturesets anywhere in the type of a mutable variable. --- .../tools/dotc/typer/CheckCaptures.scala | 54 ++++++++++++------- tests/neg-custom-args/captures/vars.check | 8 ++- tests/neg-custom-args/captures/vars.scala | 2 + 3 files changed, 43 insertions(+), 21 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala index c550b514e3ff..f7a14f259f24 100644 --- a/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala +++ b/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala @@ -403,30 +403,46 @@ class CheckCaptures extends Recheck: show(unit.tpdTree) // this dows not print tree, but makes its variables visible for dependency printing } + def checkNotGlobal(tree: Tree, tp: Type, allArgs: Tree*)(using Context): Unit = + for ref <-tp.captureSet.elems do + val isGlobal = ref match + case ref: TermRef => + ref.isRootCapability || ref.prefix != NoPrefix && ref.symbol.hasAnnotation(defn.AbilityAnnot) + case _ => false + if isGlobal then + val what = if ref.isRootCapability then "universal" else "global" + val notAllowed = i" is not allowed to capture the $what capability $ref" + def msg = + if allArgs.isEmpty then + i"type of mutable variable ${knownType(tree)}$notAllowed" + else tree match + case tree: InferredTypeTree => + i"""inferred type argument ${knownType(tree)}$notAllowed + | + |The inferred arguments are: [${allArgs.map(knownType)}%, %]""" + case _ => s"type argument$notAllowed" + report.error(msg, tree.srcPos) + def checkNotGlobal(tree: Tree, allArgs: Tree*)(using Context): Unit = if disallowGlobal then tree match case LambdaTypeTree(_, restpt) => checkNotGlobal(restpt, allArgs*) case _ => - for ref <- knownType(tree).captureSet.elems do - val isGlobal = ref match - case ref: TermRef => - ref.isRootCapability || ref.prefix != NoPrefix && ref.symbol.hasAnnotation(defn.AbilityAnnot) - case _ => false - val what = if ref.isRootCapability then "universal" else "global" - if isGlobal then - val notAllowed = i" is not allowed to capture the $what capability $ref" - def msg = - if allArgs.isEmpty then - i"type of mutable variable ${knownType(tree)}$notAllowed" - else tree match - case tree: InferredTypeTree => - i"""inferred type argument ${knownType(tree)}$notAllowed - | - |The inferred arguments are: [${allArgs.map(knownType)}%, %]""" - case _ => s"type argument$notAllowed" - report.error(msg, tree.srcPos) + checkNotGlobal(tree, knownType(tree), allArgs*) + + def checkNotGlobalDeep(tree: Tree)(using Context): Unit = + val checker = new TypeTraverser: + def traverse(tp: Type): Unit = tp match + case tp: TypeRef => + tp.info match + case TypeBounds(_, hi) => traverse(hi) + case _ => + case tp: TermRef => + case _ => + checkNotGlobal(tree, tp) + traverseChildren(tp) + checker.traverse(knownType(tree)) object PostRefinerCheck extends TreeTraverser: def traverse(tree: Tree)(using Context) = @@ -467,7 +483,7 @@ class CheckCaptures extends Recheck: case _ => inferred.foreachPart(checkPure, StopAt.Static) case t: ValDef if t.symbol.is(Mutable) => - checkNotGlobal(t.tpt) + checkNotGlobalDeep(t.tpt) case _ => traverseChildren(tree) diff --git a/tests/neg-custom-args/captures/vars.check b/tests/neg-custom-args/captures/vars.check index ceba6f5fb422..4eab5b6b2b3a 100644 --- a/tests/neg-custom-args/captures/vars.check +++ b/tests/neg-custom-args/captures/vars.check @@ -9,8 +9,12 @@ longer explanation available when compiling with `-explain` 13 | var a: {*} String => String = f // error | ^^^^^^^^^^^^^^^^^^^ | type of mutable variable box {*} String => String is not allowed to capture the universal capability (* : Any) --- Error: tests/neg-custom-args/captures/vars.scala:27:2 --------------------------------------------------------------- -27 | local { cap3 => // error +-- Error: tests/neg-custom-args/captures/vars.scala:14:9 --------------------------------------------------------------- +14 | var b: List[{*} String => String] = Nil // error + | ^^^^^^^^^^^^^^^^^^^^^^^^^^ + |type of mutable variable List[box {*} String => String] is not allowed to capture the universal capability (* : Any) +-- Error: tests/neg-custom-args/captures/vars.scala:29:2 --------------------------------------------------------------- +29 | local { cap3 => // error | ^^^^^ |inferred type argument {*} (x$0: ? String) => ? String is not allowed to capture the universal capability (* : Any) | diff --git a/tests/neg-custom-args/captures/vars.scala b/tests/neg-custom-args/captures/vars.scala index 8c80eec75783..4a58f79932b3 100644 --- a/tests/neg-custom-args/captures/vars.scala +++ b/tests/neg-custom-args/captures/vars.scala @@ -11,11 +11,13 @@ def test(cap1: Cap, cap2: Cap) = val z2c: () => Unit = z2 // error var a: {*} String => String = f // error + var b: List[{*} String => String] = Nil // error def scope = val cap3: Cap = CC() def g(x: String): String = if cap3 == cap3 then "" else "a" a = g + b = List(g) val gc = g g From 0b130e38ea6bd96e2a1d1dcbbb7edd3b4442cbe1 Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Sun, 12 Dec 2021 19:15:29 +0100 Subject: [PATCH 08/24] Introduce @capability annotations This replaces the earlier @ability annotation. The mechanisms are different, though. @ability was an annotation on `val`s whereas `capability` is an annotation on classes. --- .../src/dotty/tools/dotc/cc/CaptureSet.scala | 4 ++- .../dotty/tools/dotc/core/Definitions.scala | 2 +- .../src/dotty/tools/dotc/core/Types.scala | 1 - .../tools/dotc/typer/CheckCaptures.scala | 3 +- .../scala/annotation/ability.scala | 9 ------ library/src/scala/annotation/capability.scala | 13 +++++++++ tests/disabled/pos/lazylist.scala | 3 +- tests/neg-custom-args/captures/byname.scala | 3 +- .../captures/capt-capability.scala | 28 +++++++++++++++++++ tests/pos-custom-args/captures/try3.scala | 2 +- tests/pos-deep-subtype/i4036.scala | 12 -------- 11 files changed, 49 insertions(+), 31 deletions(-) delete mode 100644 library/src-bootstrapped/scala/annotation/ability.scala create mode 100644 library/src/scala/annotation/capability.scala create mode 100644 tests/pos-custom-args/captures/capt-capability.scala diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala b/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala index f8ca2f87e3c5..f81c7bbc7b4c 100644 --- a/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala +++ b/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala @@ -523,7 +523,9 @@ object CaptureSet: tp.captureSet case tp: TermParamRef => tp.captureSet - case _: TypeRef | _: TypeParamRef => + case _: TypeRef => + if tp.classSymbol.hasAnnotation(defn.CapabilityAnnot) then universal else empty + case _: TypeParamRef => empty case CapturingType(parent, refs, _) => recur(parent) ++ refs diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 689a81ab7a32..5368cc8d38e5 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -903,6 +903,7 @@ class Definitions { @tu lazy val BeanPropertyAnnot: ClassSymbol = requiredClass("scala.beans.BeanProperty") @tu lazy val BooleanBeanPropertyAnnot: ClassSymbol = requiredClass("scala.beans.BooleanBeanProperty") @tu lazy val BodyAnnot: ClassSymbol = requiredClass("scala.annotation.internal.Body") + @tu lazy val CapabilityAnnot: ClassSymbol = requiredClass("scala.annotation.capability") @tu lazy val ChildAnnot: ClassSymbol = requiredClass("scala.annotation.internal.Child") @tu lazy val ContextResultCountAnnot: ClassSymbol = requiredClass("scala.annotation.internal.ContextResultCount") @tu lazy val ProvisionalSuperClassAnnot: ClassSymbol = requiredClass("scala.annotation.internal.ProvisionalSuperClass") @@ -948,7 +949,6 @@ class Definitions { @tu lazy val VarargsAnnot: ClassSymbol = requiredClass("scala.annotation.varargs") @tu lazy val SinceAnnot: ClassSymbol = requiredClass("scala.annotation.since") @tu lazy val RetainsAnnot: ClassSymbol = requiredClass("scala.retains") - @tu lazy val AbilityAnnot: ClassSymbol = requiredClass("scala.annotation.ability") @tu lazy val JavaRepeatableAnnot: ClassSymbol = requiredClass("java.lang.annotation.Repeatable") diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index a64cab59b2aa..8a532c9bf7e5 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -2728,7 +2728,6 @@ object Types { def canBeTracked(using Context) = ((prefix eq NoPrefix) || symbol.is(ParamAccessor) && (prefix eq symbol.owner.thisType) - || symbol.hasAnnotation(defn.AbilityAnnot) || isRootCapability ) && !symbol.is(Method) diff --git a/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala index f7a14f259f24..b842b76f84a3 100644 --- a/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala +++ b/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala @@ -406,8 +406,7 @@ class CheckCaptures extends Recheck: def checkNotGlobal(tree: Tree, tp: Type, allArgs: Tree*)(using Context): Unit = for ref <-tp.captureSet.elems do val isGlobal = ref match - case ref: TermRef => - ref.isRootCapability || ref.prefix != NoPrefix && ref.symbol.hasAnnotation(defn.AbilityAnnot) + case ref: TermRef => ref.isRootCapability case _ => false if isGlobal then val what = if ref.isRootCapability then "universal" else "global" diff --git a/library/src-bootstrapped/scala/annotation/ability.scala b/library/src-bootstrapped/scala/annotation/ability.scala deleted file mode 100644 index 8b327a2f8b02..000000000000 --- a/library/src-bootstrapped/scala/annotation/ability.scala +++ /dev/null @@ -1,9 +0,0 @@ -package scala.annotation - -/** An annotation inidcating that a val should be tracked as its own ability. - * Example: - * - * @ability erased val canThrow: * = ??? - * ^^^ rename to capability - */ -class ability extends StaticAnnotation \ No newline at end of file diff --git a/library/src/scala/annotation/capability.scala b/library/src/scala/annotation/capability.scala new file mode 100644 index 000000000000..15504acc3258 --- /dev/null +++ b/library/src/scala/annotation/capability.scala @@ -0,0 +1,13 @@ +package scala.annotation + +/** Marks an annotated class as a capabulity. + * If the annotation is present and -Ycc is set, any (possibly aliased + * or refined) instance of the class type is implicitly augmented with + * the universal capture set. Example + * + * @capability class CanThrow[T] + * + * THere, the capture set of any instance of `CanThrow` is assumed to be + * `{*}`. + */ +final class capability extends StaticAnnotation diff --git a/tests/disabled/pos/lazylist.scala b/tests/disabled/pos/lazylist.scala index be628113d2d8..958f4c35aaf0 100644 --- a/tests/disabled/pos/lazylist.scala +++ b/tests/disabled/pos/lazylist.scala @@ -34,8 +34,7 @@ object LazyNil extends LazyList[Nothing]: def map[A, B](xs: {*} LazyList[A], f: {*} A => B): {f, xs} LazyList[B] = xs.map(f) -class CC -type Cap = {*} CC +@annotation.capability class Cap def test(cap1: Cap, cap2: Cap, cap3: Cap) = def f[T](x: LazyList[T]): LazyList[T] = if cap1 == cap1 then x else LazyNil diff --git a/tests/neg-custom-args/captures/byname.scala b/tests/neg-custom-args/captures/byname.scala index 526cdc50952f..ef5876be2c11 100644 --- a/tests/neg-custom-args/captures/byname.scala +++ b/tests/neg-custom-args/captures/byname.scala @@ -1,5 +1,4 @@ -class CC -type Cap = {*} CC +@annotation.capability class Cap def test(cap1: Cap, cap2: Cap) = def f() = if cap1 == cap1 then g else g diff --git a/tests/pos-custom-args/captures/capt-capability.scala b/tests/pos-custom-args/captures/capt-capability.scala new file mode 100644 index 000000000000..41da15d288f1 --- /dev/null +++ b/tests/pos-custom-args/captures/capt-capability.scala @@ -0,0 +1,28 @@ +import annotation.capability + +@capability class Cap +def f1(c: Cap): {c} () => c.type = () => c // ok + +def f2: Int = + val g: {*} Boolean => Int = ??? + val x = g(true) + x + +def f3: Int = + def g: {*} Boolean => Int = ??? + def h = g + val x = g.apply(true) + x + +def foo() = + val x: Cap = ??? + val y: Cap = x + val x2: {x} () => Cap = ??? + val y2: {x} () => Cap = x2 + + val z1: {*} () => Cap = f1(x) + def h[X](a: X)(b: X) = a + + val z2 = + if x == null then () => x else () => Cap() + x diff --git a/tests/pos-custom-args/captures/try3.scala b/tests/pos-custom-args/captures/try3.scala index 074517d8a9e5..b29ad2d4b352 100644 --- a/tests/pos-custom-args/captures/try3.scala +++ b/tests/pos-custom-args/captures/try3.scala @@ -1,5 +1,5 @@ import language.experimental.erasedDefinitions -import annotation.ability +import annotation.capability import java.io.IOException class CT[-E] // variance is needed for correct rechecking inference diff --git a/tests/pos-deep-subtype/i4036.scala b/tests/pos-deep-subtype/i4036.scala index 1784a9189fab..08ff248caf9d 100644 --- a/tests/pos-deep-subtype/i4036.scala +++ b/tests/pos-deep-subtype/i4036.scala @@ -11,9 +11,6 @@ object A { x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x. x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x. x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x. - x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x. - x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x. - x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x. x.type 1: v. @@ -25,16 +22,10 @@ object A { x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x. x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x. x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x. - x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x. - x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x. - x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x. x.T val u = new B { type T = Int } u: u. - x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x. - x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x. - x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x. x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x. x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x. x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x.x. @@ -55,8 +46,5 @@ object A { T#T#T#T#T#T#T#T#T#T#T#T#T#T#T#T# T#T#T#T#T#T#T#T#T#T#T#T#T#T#T#T# T#T#T#T#T#T#T#T#T#T#T#T#T#T#T#T# - T#T#T#T#T#T#T#T#T#T#T#T#T#T#T#T# - T#T#T#T#T#T#T#T#T#T#T#T#T#T#T#T# - T#T#T#T#T#T#T#T#T#T#T#T#T#T#T#T# T#T#T#T#T#T#T#T#T#T#T#T#T#T#T#T } From 15133933c0c3febff3d8b8c39465c931cc72600f Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Sun, 12 Dec 2021 21:12:16 +0100 Subject: [PATCH 09/24] Fix typo and update semanticdb.expect --- library/src/scala/annotation/capability.scala | 2 +- tests/semanticdb/metac.expect | 83 ++----------------- 2 files changed, 9 insertions(+), 76 deletions(-) diff --git a/library/src/scala/annotation/capability.scala b/library/src/scala/annotation/capability.scala index 15504acc3258..98c6e15e023a 100644 --- a/library/src/scala/annotation/capability.scala +++ b/library/src/scala/annotation/capability.scala @@ -1,6 +1,6 @@ package scala.annotation -/** Marks an annotated class as a capabulity. +/** Marks an annotated class as a capability. * If the annotation is present and -Ycc is set, any (possibly aliased * or refined) instance of the class type is implicitly augmented with * the universal capture set. Example diff --git a/tests/semanticdb/metac.expect b/tests/semanticdb/metac.expect index 732e9b07aa85..d03def1b603c 100644 --- a/tests/semanticdb/metac.expect +++ b/tests/semanticdb/metac.expect @@ -319,7 +319,6 @@ Text => empty Language => Scala Symbols => 14 entries Occurrences => 30 entries -Synthetics => 2 entries Symbols: example/Anonymous# => class Anonymous extends Object { self: Anonymous & Anonymous => +6 decls } @@ -369,10 +368,6 @@ Occurrences: [18:6..18:9): foo <- example/Anonymous#foo. [18:16..18:19): Foo -> example/Anonymous#Foo# -Synthetics: -[10:2..10:9):locally => *[Unit] -[13:2..13:9):locally => *[Unit] - expect/AnonymousGiven.scala --------------------------- @@ -407,7 +402,6 @@ Text => empty Language => Scala Symbols => 109 entries Occurrences => 113 entries -Synthetics => 2 entries Symbols: classes/C1# => final class C1 extends AnyVal { self: C1 => +2 decls } @@ -635,10 +629,6 @@ Occurrences: [53:4..53:9): local -> local4 [53:10..53:11): + -> scala/Int#`+`(+4). -Synthetics: -[51:16..51:27):List(1).map => *[Int] -[51:16..51:20):List => *.apply[Int] - expect/Empty.scala ------------------ @@ -841,7 +831,7 @@ Text => empty Language => Scala Symbols => 181 entries Occurrences => 148 entries -Synthetics => 10 entries +Synthetics => 8 entries Symbols: _empty_/Enums. => final object Enums extends Object { self: Enums.type => +30 decls } @@ -1184,8 +1174,6 @@ Synthetics: [52:31..52:50):identity[Option[B]] => *[Function1[A, Option[B]]] [54:14..54:18):Some => *.apply[Some[Int]] [54:14..54:34):Some(Some(1)).unwrap => *(given_<:<_T_T[Option[Int]]) -[54:19..54:23):Some => *.apply[Int] -[54:28..54:34):unwrap => *[Some[Int], Int] [56:52..56:64):Enum[Planet] => *[Planet] expect/EtaExpansion.scala @@ -1198,7 +1186,7 @@ Text => empty Language => Scala Symbols => 3 entries Occurrences => 8 entries -Synthetics => 5 entries +Synthetics => 1 entries Symbols: example/EtaExpansion# => class EtaExpansion extends Object { self: EtaExpansion => +1 decls } @@ -1216,11 +1204,7 @@ Occurrences: [4:25..4:26): + -> java/lang/String#`+`(). Synthetics: -[3:2..3:13):Some(1).map => *[Int] -[3:2..3:6):Some => *.apply[Int] -[3:14..3:22):identity => *[Int] [4:2..4:18):List(1).foldLeft => *[String] -[4:2..4:6):List => *.apply[Int] expect/Example.scala -------------------- @@ -1370,7 +1354,7 @@ Text => empty Language => Scala Symbols => 13 entries Occurrences => 52 entries -Synthetics => 6 entries +Synthetics => 2 entries Symbols: example/ForComprehension# => class ForComprehension extends Object { self: ForComprehension => +1 decls } @@ -1442,10 +1426,6 @@ Occurrences: [41:6..41:7): f -> local10 Synthetics: -[4:9..4:13):List => *.apply[Int] -[5:9..5:13):List => *.apply[Int] -[10:9..10:13):List => *.apply[Int] -[11:9..11:13):List => *.apply[Int] [19:9..19:13):List => *.apply[Tuple2[Int, Int]] [33:9..33:13):List => *.apply[Tuple4[Int, Int, Int, Int]] @@ -1459,7 +1439,6 @@ Text => empty Language => Scala Symbols => 29 entries Occurrences => 65 entries -Synthetics => 3 entries Symbols: a/b/Givens. => final object Givens extends Object { self: Givens.type => +12 decls } @@ -1559,11 +1538,6 @@ Occurrences: [26:57..26:58): A -> a/b/Givens.foo().(A) [26:59..26:64): empty -> a/b/Givens.Monoid#empty(). -Synthetics: -[12:17..12:25):sayHello => *[Int] -[13:19..13:29):sayGoodbye => *[Int] -[14:18..14:27):saySoLong => *[Int] - expect/ImplicitConversion.scala ------------------------------- @@ -1963,7 +1937,6 @@ Text => empty Language => Scala Symbols => 6 entries Occurrences => 10 entries -Synthetics => 1 entries Symbols: example/Local# => class Local extends Object { self: Local => +2 decls } @@ -1985,9 +1958,6 @@ Occurrences: [4:25..4:26): a -> local1 [5:4..5:6): id -> local2 -Synthetics: -[5:4..5:6):id => *[Int] - expect/Locals.scala ------------------- @@ -1998,7 +1968,6 @@ Text => empty Language => Scala Symbols => 3 entries Occurrences => 6 entries -Synthetics => 1 entries Symbols: local0 => val local x: Int @@ -2013,9 +1982,6 @@ Occurrences: [5:4..5:8): List -> scala/package.List. [5:9..5:10): x -> local0 -Synthetics: -[5:4..5:8):List => *.apply[Int] - expect/MetacJava.scala ---------------------- @@ -2113,7 +2079,7 @@ Text => empty Language => Scala Symbols => 3 entries Occurrences => 80 entries -Synthetics => 2 entries +Synthetics => 1 entries Symbols: example/MethodUsages# => class MethodUsages extends Object { self: MethodUsages => +2 decls } @@ -2203,7 +2169,6 @@ Occurrences: [34:8..34:9): m -> example/Methods#m17.m(). Synthetics: -[13:2..13:6):m.m7 => *[Int] [13:2..13:26):m.m7(m, new m.List[Int]) => *(Int) expect/Methods.scala @@ -3056,7 +3021,7 @@ Text => empty Language => Scala Symbols => 52 entries Occurrences => 132 entries -Synthetics => 36 entries +Synthetics => 23 entries Symbols: example/Synthetic# => class Synthetic extends Object { self: Synthetic => +23 decls } @@ -3247,26 +3212,17 @@ Occurrences: [58:6..58:9): foo -> example/Synthetic#Contexts.foo(). Synthetics: -[5:2..5:13):List(1).map => *[Int] -[5:2..5:6):List => *.apply[Int] [6:2..6:18):Array.empty[Int] => intArrayOps(*) [7:2..7:8):"fooo" => augmentString(*) [10:13..10:24):"name:(.*)" => augmentString(*) -[11:17..11:25):LazyList => *.apply[Int] -[13:4..13:28):#:: 2 #:: LazyList.empty => *[Int] [13:8..13:28):2 #:: LazyList.empty => toDeferrer[Int](*) -[13:10..13:28):#:: LazyList.empty => *[Int] [13:14..13:28):LazyList.empty => toDeferrer[Nothing](*) [13:14..13:28):LazyList.empty => *[Nothing] -[15:25..15:33):LazyList => *.apply[Int] -[17:14..17:38):#:: 2 #:: LazyList.empty => *[Int] [17:18..17:38):2 #:: LazyList.empty => toDeferrer[Int](*) -[17:20..17:38):#:: LazyList.empty => *[Int] [17:24..17:38):LazyList.empty => toDeferrer[Nothing](*) [17:24..17:38):LazyList.empty => *[Nothing] [19:12..19:13):1 => intWrapper(*) [19:26..19:27):0 => intWrapper(*) -[19:46..19:50):x -> => *[Int] [19:46..19:47):x => ArrowAssoc[Int](*) [20:12..20:13):1 => intWrapper(*) [20:26..20:27):0 => intWrapper(*) @@ -3275,10 +3231,6 @@ Synthetics: [32:35..32:49):Array.empty[T] => *(evidence$1) [36:22..36:27):new F => orderingToOrdered[F](*) [36:22..36:27):new F => *(ordering) -[40:9..40:43):scala.concurrent.Future.successful => *[Int] -[41:9..41:43):scala.concurrent.Future.successful => *[Int] -[44:9..44:43):scala.concurrent.Future.successful => *[Int] -[45:9..45:43):scala.concurrent.Future.successful => *[Int] [51:24..51:30):foo(0) => *(x$1) [52:27..52:33):foo(0) => *(x) [55:6..55:12):foo(x) => *(x) @@ -3334,7 +3286,7 @@ Text => empty Language => Scala Symbols => 22 entries Occurrences => 46 entries -Synthetics => 7 entries +Synthetics => 2 entries Symbols: example/ValPattern# => class ValPattern extends Object { self: ValPattern => +14 decls } @@ -3409,13 +3361,8 @@ Occurrences: [40:10..40:18): rightVar -> local4 Synthetics: -[6:4..6:8):Some => *.apply[Int] [8:6..8:10):List => *.unapplySeq[Nothing] [8:11..8:15):Some => *.unapply[Nothing] -[12:4..12:8):Some => *.apply[Int] -[25:4..25:11):locally => *[Unit] -[28:8..28:12):Some => *.apply[Int] -[32:8..32:12):Some => *.apply[Int] expect/Vals.scala ----------------- @@ -3937,7 +3884,6 @@ Text => empty Language => Scala Symbols => 3 entries Occurrences => 6 entries -Synthetics => 1 entries Symbols: example/`local-file`# => class local-file extends Object { self: local-file => +1 decls } @@ -3952,9 +3898,6 @@ Occurrences: [5:4..5:9): local -> local0 [5:10..5:11): + -> scala/Int#`+`(+4). -Synthetics: -[3:2..3:9):locally => *[Int] - expect/nullary.scala -------------------- @@ -3965,7 +3908,6 @@ Text => empty Language => Scala Symbols => 17 entries Occurrences => 29 entries -Synthetics => 1 entries Symbols: _empty_/Concrete# => class Concrete extends NullaryTest[Int, List] { self: Concrete => +3 decls } @@ -4017,9 +3959,6 @@ Occurrences: [18:7..18:15): Concrete -> _empty_/Concrete# [18:17..18:25): nullary3 -> _empty_/Concrete#nullary3(). -Synthetics: -[13:17..13:21):List => *.apply[Int] - expect/recursion.scala ---------------------- @@ -4308,7 +4247,6 @@ Text => empty Language => Scala Symbols => 144 entries Occurrences => 225 entries -Synthetics => 1 entries Symbols: local0 => abstract method k => Int @@ -4683,9 +4621,6 @@ Occurrences: [119:32..119:38): Option -> scala/Option# [119:39..119:42): Int -> scala/Int# -Synthetics: -[68:20..68:24):@ann => *[Int] - expect/semanticdb-extract.scala ------------------------------- @@ -4696,7 +4631,7 @@ Text => empty Language => Scala Symbols => 18 entries Occurrences => 20 entries -Synthetics => 3 entries +Synthetics => 2 entries Symbols: _empty_/AnObject. => final object AnObject extends Object { self: AnObject.type => +6 decls } @@ -4741,7 +4676,6 @@ Occurrences: [16:20..16:23): Int -> scala/Int# Synthetics: -[11:2..11:6):List => *.apply[Int] [12:2..12:12):List.apply => *[Nothing] [13:2..13:14):List.`apply` => *[Nothing] @@ -4755,7 +4689,7 @@ Text => empty Language => Scala Symbols => 18 entries Occurrences => 43 entries -Synthetics => 2 entries +Synthetics => 1 entries Symbols: _empty_/MyProgram# => final class MyProgram extends Object { self: MyProgram => +2 decls } @@ -4823,6 +4757,5 @@ Occurrences: [7:30..7:33): foo -> _empty_/toplevel$package.foo(). Synthetics: -[5:40..5:60):(1 to times) foreach => *[Unit] [5:41..5:42):1 => intWrapper(*) From 81512d5dd250801c1dd9dbdbbf1052cadee4e0e1 Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Sun, 12 Dec 2021 22:00:50 +0100 Subject: [PATCH 10/24] Disable MiMa checks --- .github/workflows/ci.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 405225011cdd..5f488634956b 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -114,9 +114,9 @@ jobs: ./project/scripts/sbt ";dist/pack; scala3-bootstrapped/compile; scala3-bootstrapped/test;sjsSandbox/run;sjsSandbox/test;sjsJUnitTests/test;sjsCompilerTests/test ;sbt-test/scripted scala2-compat/* ;configureIDE ;stdlib-bootstrapped/test:run ;stdlib-bootstrapped-tasty-tests/test; scala3-compiler-bootstrapped/scala3CompilerCoursierTest:test" ./project/scripts/bootstrapCmdTests - - name: MiMa - run: | - ./project/scripts/sbt ";scala3-interfaces/mimaReportBinaryIssues ;scala3-library-bootstrapped/mimaReportBinaryIssues ;scala3-library-bootstrappedJS/mimaReportBinaryIssues; tasty-core-bootstrapped/mimaReportBinaryIssues" +# - name: MiMa +# run: | +# ./project/scripts/sbt ";scala3-interfaces/mimaReportBinaryIssues ;scala3-library-bootstrapped/mimaReportBinaryIssues ;scala3-library-bootstrappedJS/mimaReportBinaryIssues; tasty-core-bootstrapped/mimaReportBinaryIssues" test_windows_fast: runs-on: [self-hosted, Windows] From 12f17a4f2dce9f7a80c52174e6214b349c635771 Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Thu, 23 Dec 2021 19:36:53 +0100 Subject: [PATCH 11/24] Special rule for {this} in capture sets of class members Consider the lazylists.scala test in pos-custom-args/captures: ```scala class CC type Cap = {*} CC trait LazyList[+A]: this: ({*} LazyList[A]) => def isEmpty: Boolean def head: A def tail: {this} LazyList[A] object LazyNil extends LazyList[Nothing]: def isEmpty: Boolean = true def head = ??? def tail = ??? extension [A](xs: {*} LazyList[A]) def map[B](f: {*} A => B): {xs, f} LazyList[B] = class Mapped extends LazyList[B]: this: ({xs, f} Mapped) => def isEmpty = false def head: B = f(xs.head) def tail: {this} LazyList[B] = xs.tail.map(f) // OK new Mapped ``` Without this commit, the second to last line is an error since the right hand side has capture set `{xs, f}` but the required capture set is `this`. To fix this, we widen the expected type of the rhs `xs.tail.map(f)` from `{this}` to `{this, f, xs}`. That is, we add the declared captures of the self type to the expected type. The soundness argument for doing this is as follows: Since `tail` does not have parameters, the only thing it could capture are references that the receiver `this` captures as well. So `xs` and `f` must come via `this`. For instance, if the receiver `xs` of `xs.tail` happens to be pure, then `xs.tail` is pure as well. On the other hand, in the neg test `lazylists1.scala` we add the following line to `Mapped`: ```scala def concat(other: {f} LazyList[A]): {this} LazyList[A] = ??? : ({xs, f} LazyList[A]) // error ``` Here, we cannot widen the expected type from `{this}` to `{this, xs, f}` since the result of concat refers to `f` independently of `this`, namely through its parameter `other`. Hence, if `ys: {f} LazyList[String]` then ``` LazyNil.concat(ys) ``` still refers to `f` even though `LazyNil` is pure. But if we would accept the definition of `concat` above, the type of `LazyNil.concat(ys)` would be `LazyList[String]`, which is unsound. The current implementation widens the expected type of class members if the class member does not have tracked parameters. We could potentially refine this to say we widen with all references in the expected type that are not subsumed by one of the parameter types. ## Changes: ### Refine rule for this widening We now widen the expected type of the right hand side of a class member as follows: Add all references of the declared type of this that are not subsumed by a capture set of a parameter type. ### Do expected type widening only in final classes Alex found a counter-example why this is required. See map5 in neg-customargs/captures/lazylists2.scala --- .../src/dotty/tools/dotc/cc/CaptureSet.scala | 9 ++- .../dotty/tools/dotc/transform/Recheck.scala | 7 +- .../tools/dotc/typer/CheckCaptures.scala | 16 +++++ .../neg-custom-args/captures/lazylists1.check | 7 ++ .../neg-custom-args/captures/lazylists1.scala | 27 ++++++++ .../neg-custom-args/captures/lazylists2.check | 45 +++++++++++++ .../neg-custom-args/captures/lazylists2.scala | 64 +++++++++++++++++++ .../captures/lazylists-mono.scala | 27 ++++++++ .../pos-custom-args/captures/lazylists.scala | 42 ++++++++++++ .../pos-custom-args/captures/lazylists1.scala | 35 ++++++++++ 10 files changed, 276 insertions(+), 3 deletions(-) create mode 100644 tests/neg-custom-args/captures/lazylists1.check create mode 100644 tests/neg-custom-args/captures/lazylists1.scala create mode 100644 tests/neg-custom-args/captures/lazylists2.check create mode 100644 tests/neg-custom-args/captures/lazylists2.scala create mode 100644 tests/pos-custom-args/captures/lazylists-mono.scala create mode 100644 tests/pos-custom-args/captures/lazylists.scala create mode 100644 tests/pos-custom-args/captures/lazylists1.scala diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala b/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala index f81c7bbc7b4c..cd8d67399d8d 100644 --- a/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala +++ b/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala @@ -49,7 +49,14 @@ sealed abstract class CaptureSet extends Showable: /** Is this capture set definitely non-empty? */ final def isNotEmpty: Boolean = !elems.isEmpty - /** Cast to variable. @pre: @isConst */ + /** Cast to Const. @pre: isConst */ + def asConst: Const = this match + case c: Const => c + case v: Var => + assert(v.isConst) + Const(v.elems) + + /** Cast to variable. @pre: !isConst */ def asVar: Var = assert(!isConst) asInstanceOf[Var] diff --git a/compiler/src/dotty/tools/dotc/transform/Recheck.scala b/compiler/src/dotty/tools/dotc/transform/Recheck.scala index 55a57ede2e0a..924a444aeff4 100644 --- a/compiler/src/dotty/tools/dotc/transform/Recheck.scala +++ b/compiler/src/dotty/tools/dotc/transform/Recheck.scala @@ -203,12 +203,15 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: bindType def recheckValDef(tree: ValDef, sym: Symbol)(using Context): Unit = - if !tree.rhs.isEmpty then recheck(tree.rhs, sym.info) + if !tree.rhs.isEmpty then recheckRHS(tree.rhs, sym.info, sym) def recheckDefDef(tree: DefDef, sym: Symbol)(using Context): Unit = val rhsCtx = linkConstructorParams(sym).withOwner(sym) if !tree.rhs.isEmpty && !sym.isInlineMethod && !sym.isEffectivelyErased then - inContext(rhsCtx) { recheck(tree.rhs, recheck(tree.tpt)) } + inContext(rhsCtx) { recheckRHS(tree.rhs, recheck(tree.tpt), sym) } + + def recheckRHS(tree: Tree, pt: Type, sym: Symbol)(using Context): Type = + recheck(tree, pt) def recheckTypeDef(tree: TypeDef, sym: Symbol)(using Context): Type = recheck(tree.rhs) diff --git a/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala index b842b76f84a3..1f30dd989f3a 100644 --- a/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala +++ b/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala @@ -322,6 +322,22 @@ class CheckCaptures extends Recheck: interpolateVarsIn(tree.tpt) curEnv = saved + override def recheckRHS(tree: Tree, pt: Type, sym: Symbol)(using Context): Type = + val pt1 = pt match + case CapturingType(core, refs, _) + if sym.owner.isClass && !sym.owner.isExtensibleClass + && refs.elems.contains(sym.owner.thisType) => + val paramCaptures = + sym.paramSymss.flatten.foldLeft(CaptureSet.empty) { (cs, p) => + val pcs = p.info.captureSet + (cs ++ (if pcs.isConst then pcs else CaptureSet.universal)).asConst + } + val declaredCaptures = sym.owner.asClass.givenSelfType.captureSet + pt.derivedCapturingType(core, refs ++ (declaredCaptures -- paramCaptures)) + case _ => + pt + recheck(tree, pt1) + override def recheckClassDef(tree: TypeDef, impl: Template, cls: ClassSymbol)(using Context): Type = for param <- cls.paramGetters do if param.is(Private) && !param.info.captureSet.isAlwaysEmpty then diff --git a/tests/neg-custom-args/captures/lazylists1.check b/tests/neg-custom-args/captures/lazylists1.check new file mode 100644 index 000000000000..29291c8044c0 --- /dev/null +++ b/tests/neg-custom-args/captures/lazylists1.check @@ -0,0 +1,7 @@ +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylists1.scala:25:63 ----------------------------------- +25 | def concat(other: {f} LazyList[A]): {this} LazyList[A] = ??? : ({xs, f} LazyList[A]) // error + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | Found: {xs, f} LazyList[A] + | Required: {Mapped.this, xs} LazyList[A] + +longer explanation available when compiling with `-explain` diff --git a/tests/neg-custom-args/captures/lazylists1.scala b/tests/neg-custom-args/captures/lazylists1.scala new file mode 100644 index 000000000000..02c7cb4ff3e5 --- /dev/null +++ b/tests/neg-custom-args/captures/lazylists1.scala @@ -0,0 +1,27 @@ +class CC +type Cap = {*} CC + +trait LazyList[+A]: + this: ({*} LazyList[A]) => + + def isEmpty: Boolean + def head: A + def tail: {this} LazyList[A] + +object LazyNil extends LazyList[Nothing]: + def isEmpty: Boolean = true + def head = ??? + def tail = ??? + +extension [A](xs: {*} LazyList[A]) + def map[B](f: {*} A => B): {xs, f} LazyList[B] = + final class Mapped extends LazyList[B]: + this: ({xs, f} Mapped) => + + def isEmpty = false + def head: B = f(xs.head) + def tail: {this} LazyList[B] = xs.tail.map(f) // OK + def drop(n: Int): {this} LazyList[B] = ??? : ({xs, f} LazyList[B]) // OK + def concat(other: {f} LazyList[A]): {this} LazyList[A] = ??? : ({xs, f} LazyList[A]) // error + new Mapped + diff --git a/tests/neg-custom-args/captures/lazylists2.check b/tests/neg-custom-args/captures/lazylists2.check new file mode 100644 index 000000000000..8e09dd26cccf --- /dev/null +++ b/tests/neg-custom-args/captures/lazylists2.check @@ -0,0 +1,45 @@ +-- [E163] Declaration Error: tests/neg-custom-args/captures/lazylists2.scala:50:10 ------------------------------------- +50 | def tail: {xs, f} LazyList[B] = xs.tail.map(f) // error + | ^ + | error overriding method tail in trait LazyList of type => {Mapped.this} LazyList[B]; + | method tail of type => {xs, f} LazyList[B] has incompatible type + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylists2.scala:18:4 ------------------------------------ +18 | final class Mapped extends LazyList[B]: // error + | ^ + | Found: {f, xs} LazyList[B] + | Required: {f} LazyList[B] +19 | this: ({xs, f} Mapped) => +20 | def isEmpty = false +21 | def head: B = f(xs.head) +22 | def tail: {this} LazyList[B] = xs.tail.map(f) +23 | new Mapped + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylists2.scala:27:4 ------------------------------------ +27 | final class Mapped extends LazyList[B]: // error + | ^ + | Found: {f, xs} LazyList[B] + | Required: {xs} LazyList[B] +28 | this: ({xs, f} Mapped) => +29 | def isEmpty = false +30 | def head: B = f(xs.head) +31 | def tail: {this} LazyList[B] = xs.tail.map(f) +32 | new Mapped + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylists2.scala:41:48 ----------------------------------- +41 | def tail: {this} LazyList[B] = xs.tail.map(f) // error + | ^^^^^^^^^^^^^^ + | Found: {f} LazyList[B] + | Required: {xs} LazyList[B] + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylists2.scala:59:48 ----------------------------------- +59 | def tail: {this} LazyList[B] = xs.tail.map(f) // error + | ^^^^^^^^^^^^^^ + | Found: {f} LazyList[B] + | Required: {Mapped.this} LazyList[B] + +longer explanation available when compiling with `-explain` diff --git a/tests/neg-custom-args/captures/lazylists2.scala b/tests/neg-custom-args/captures/lazylists2.scala new file mode 100644 index 000000000000..c31a1ae5d04f --- /dev/null +++ b/tests/neg-custom-args/captures/lazylists2.scala @@ -0,0 +1,64 @@ +class CC +type Cap = {*} CC + +trait LazyList[+A]: + this: ({*} LazyList[A]) => + + def isEmpty: Boolean + def head: A + def tail: {this} LazyList[A] + +object LazyNil extends LazyList[Nothing]: + def isEmpty: Boolean = true + def head = ??? + def tail = ??? + +extension [A](xs: {*} LazyList[A]) + def map[B](f: {*} A => B): {f} LazyList[B] = + final class Mapped extends LazyList[B]: // error + this: ({xs, f} Mapped) => + + def isEmpty = false + def head: B = f(xs.head) + def tail: {this} LazyList[B] = xs.tail.map(f) + new Mapped + + def map2[B](f: {*} A => B): {xs} LazyList[B] = + final class Mapped extends LazyList[B]: // error + this: ({xs, f} Mapped) => + + def isEmpty = false + def head: B = f(xs.head) + def tail: {this} LazyList[B] = xs.tail.map(f) + new Mapped + + def map3[B](f: {*} A => B): {xs} LazyList[B] = + final class Mapped extends LazyList[B]: + this: ({xs} Mapped) => + + def isEmpty = false + def head: B = f(xs.head) + def tail: {this} LazyList[B] = xs.tail.map(f) // error + new Mapped + + def map4[B](f: {*} A => B): {xs} LazyList[B] = + final class Mapped extends LazyList[B]: + this: ({xs, f} Mapped) => + + def isEmpty = false + def head: B = f(xs.head) + def tail: {xs, f} LazyList[B] = xs.tail.map(f) // error + new Mapped + + def map5[B](f: {*} A => B): LazyList[B] = + class Mapped extends LazyList[B]: + this: ({xs, f} Mapped) => + + def isEmpty = false + def head: B = f(xs.head) + def tail: {this} LazyList[B] = xs.tail.map(f) // error + class Mapped2 extends Mapped: + this: Mapped => + new Mapped2 + + diff --git a/tests/pos-custom-args/captures/lazylists-mono.scala b/tests/pos-custom-args/captures/lazylists-mono.scala new file mode 100644 index 000000000000..82c44abf703a --- /dev/null +++ b/tests/pos-custom-args/captures/lazylists-mono.scala @@ -0,0 +1,27 @@ +class CC +type Cap = {*} CC + +//------------------------------------------------- + +def test(E: Cap) = + + trait LazyList[+A]: + protected def contents: {E} () => (A, {E} LazyList[A]) + def isEmpty: Boolean + def head: A = contents()._1 + def tail: {E} LazyList[A] = contents()._2 + + class LazyCons[+A](override val contents: {E} () => (A, {E} LazyList[A])) + extends LazyList[A]: + def isEmpty: Boolean = false + + object LazyNil extends LazyList[Nothing]: + def contents: {E} () => (Nothing, LazyList[Nothing]) = ??? + def isEmpty: Boolean = true + + extension [A](xs: {E} LazyList[A]) + def map[B](f: {E} A => B): {E} LazyList[B] = + if xs.isEmpty then LazyNil + else + val cons = () => (f(xs.head), xs.tail.map(f)) + LazyCons(cons) diff --git a/tests/pos-custom-args/captures/lazylists.scala b/tests/pos-custom-args/captures/lazylists.scala new file mode 100644 index 000000000000..17d5f8546edc --- /dev/null +++ b/tests/pos-custom-args/captures/lazylists.scala @@ -0,0 +1,42 @@ +class CC +type Cap = {*} CC + +trait LazyList[+A]: + this: ({*} LazyList[A]) => + + def isEmpty: Boolean + def head: A + def tail: {this} LazyList[A] + +object LazyNil extends LazyList[Nothing]: + def isEmpty: Boolean = true + def head = ??? + def tail = ??? + +extension [A](xs: {*} LazyList[A]) + def map[B](f: {*} A => B): {xs, f} LazyList[B] = + final class Mapped extends LazyList[B]: + this: ({xs, f} Mapped) => + + def isEmpty = false + def head: B = f(xs.head) + def tail: {this} LazyList[B] = xs.tail.map(f) // OK + def concat(other: {f} LazyList[A]): {this, f} LazyList[A] = ??? : ({xs, f} LazyList[A]) // OK + if xs.isEmpty then LazyNil + else new Mapped + +def test(cap1: Cap, cap2: Cap) = + def f(x: String): String = if cap1 == cap1 then "" else "a" + def g(x: String): String = if cap2 == cap2 then "" else "a" + + val xs = + class Initial extends LazyList[String]: + this: ({cap1} Initial) => + + def isEmpty = false + def head = f("") + def tail = LazyNil + new Initial + val xsc: {cap1} LazyList[String] = xs + val ys = xs.map(g) + val ysc: {cap1, cap2} LazyList[String] = ys diff --git a/tests/pos-custom-args/captures/lazylists1.scala b/tests/pos-custom-args/captures/lazylists1.scala new file mode 100644 index 000000000000..4c8006fb0e29 --- /dev/null +++ b/tests/pos-custom-args/captures/lazylists1.scala @@ -0,0 +1,35 @@ +class CC +type Cap = {*} CC + +trait LazyList[+A]: + this: ({*} LazyList[A]) => + + def isEmpty: Boolean + def head: A + def tail: {this} LazyList[A] + +object LazyNil extends LazyList[Nothing]: + def isEmpty: Boolean = true + def head = ??? + def tail = ??? + +final class LazyCons[+T](val x: T, val xs: {*} () => {*} LazyList[T]) extends LazyList[T]: + this: ({*} LazyList[T]) => + + def isEmpty = false + def head = x + def tail: {this} LazyList[T] = xs() + +extension [A](xs: {*} LazyList[A]) + def map[B](f: {*} A => B): {xs, f} LazyList[B] = + if xs.isEmpty then LazyNil + else LazyCons(f(xs.head), () => xs.tail.map(f)) + +def test(cap1: Cap, cap2: Cap) = + def f(x: String): String = if cap1 == cap1 then "" else "a" + def g(x: String): String = if cap2 == cap2 then "" else "a" + + val xs = LazyCons("", () => if f("") == f("") then LazyNil else LazyNil) + val xsc: {cap1} LazyList[String] = xs + val ys = xs.map(g) + val ysc: {cap1, cap2} LazyList[String] = ys From f134adb289cc535c44b1a07941698622dc0cecba Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Sat, 25 Dec 2021 11:53:08 +0100 Subject: [PATCH 12/24] Pure function types -> and ?-> 1. Allow `->` and `?->` and function operators, treated like `=>` and `?=>`. 2. under -Ycc treat `->` and `?->` as immutable function types, whereas `A => B` is an alias of `{*} A -> B` and `A ?=> B` is an alias of `{*} A ?-> B`. Closures are unaffected, we still use `=>` for all closures where they are pure or not. Improve printing of capturing types Avoid explicit retains annotations also outside phase cc Generate "Impure" function aliases For every (possibly erased and/or context) function class XFunctionN, generate an alias ImpureXFunctionN in the Scala package defined as type ImpureXFunctionN[...] = {*} XFunctionN[...] Also: - Fix a bug in TypeComparer: glb has to test subCapture in a frozen state - Harden EventuallyCapturingType extractor to not crash on illegal capture sets - Cleanup transformation of inferred types --- compiler/src/dotty/tools/dotc/ast/untpd.scala | 6 +- .../src/dotty/tools/dotc/cc/CaptureOps.scala | 20 +- .../src/dotty/tools/dotc/cc/CaptureSet.scala | 6 + .../dotty/tools/dotc/cc/CapturingType.scala | 18 +- .../dotty/tools/dotc/core/Definitions.scala | 127 +++++++---- .../src/dotty/tools/dotc/core/Flags.scala | 4 +- .../src/dotty/tools/dotc/core/NameOps.scala | 40 ++-- .../src/dotty/tools/dotc/core/StdNames.scala | 2 + .../dotty/tools/dotc/core/TypeComparer.scala | 2 +- .../src/dotty/tools/dotc/core/Types.scala | 6 +- .../dotty/tools/dotc/parsing/Parsers.scala | 63 +++--- .../tools/dotc/printing/PlainPrinter.scala | 13 +- .../tools/dotc/printing/RefinedPrinter.scala | 52 +++-- .../tools/dotc/typer/CheckCaptures.scala | 206 ++++++++++++------ .../dotty/tools/dotc/typer/Implicits.scala | 2 +- .../src/dotty/tools/dotc/typer/Typer.scala | 9 +- .../quoted/runtime/impl/QuotesImpl.scala | 2 +- tests/neg-custom-args/capt-wf.scala | 6 +- tests/neg-custom-args/captures/bounded.scala | 4 +- tests/neg-custom-args/captures/boxmap.check | 8 +- tests/neg-custom-args/captures/boxmap.scala | 10 +- tests/neg-custom-args/captures/byname.scala | 2 +- .../captures/capt-box-env.scala | 5 +- tests/neg-custom-args/captures/capt-box.scala | 6 +- tests/neg-custom-args/captures/capt-env.scala | 2 +- tests/neg-custom-args/captures/capt1.check | 14 +- tests/neg-custom-args/captures/capt1.scala | 8 +- tests/neg-custom-args/captures/capt2.scala | 6 +- tests/neg-custom-args/captures/capt3.scala | 8 +- tests/neg-custom-args/captures/io.scala | 6 +- tests/neg-custom-args/captures/lazylist.check | 2 +- tests/neg-custom-args/captures/lazylist.scala | 6 +- .../neg-custom-args/captures/lazylists1.scala | 2 +- .../neg-custom-args/captures/lazylists2.scala | 10 +- tests/neg-custom-args/captures/lazyref.check | 8 +- tests/neg-custom-args/captures/lazyref.scala | 8 +- tests/neg-custom-args/captures/try.check | 14 +- tests/neg-custom-args/captures/try.scala | 2 +- tests/neg-custom-args/captures/vars.check | 24 +- tests/neg-custom-args/captures/vars.scala | 14 +- tests/pos-custom-args/captures/bounded.scala | 4 +- .../captures/boxmap-paper.scala | 23 +- tests/pos-custom-args/captures/boxmap.scala | 14 +- .../captures/capt-capability.scala | 12 +- .../captures/capt-depfun.scala | 12 +- .../captures/capt-depfun2.scala | 2 +- .../pos-custom-args/captures/capt-test.scala | 8 +- tests/pos-custom-args/captures/capt1.scala | 12 +- tests/pos-custom-args/captures/capt2.scala | 12 +- .../pos-custom-args/captures/cc-expand.scala | 16 +- .../pos-custom-args/captures/impurefun.scala | 8 + .../captures/lazylists-mono.scala | 8 +- .../pos-custom-args/captures/lazylists.scala | 2 +- .../pos-custom-args/captures/lazylists1.scala | 10 +- tests/pos-custom-args/captures/lazyref.scala | 11 +- .../captures/list-encoding.scala | 4 +- tests/pos-custom-args/captures/lists.scala | 87 ++++---- tests/pos-custom-args/captures/pairs.scala | 13 +- tests/pos-custom-args/captures/try.scala | 4 +- tests/pos-custom-args/captures/try3.scala | 5 +- tests/pos-custom-args/captures/vars.scala | 11 +- tests/pos/i12723.scala | 8 +- tests/pos/impurefun.scala | 4 + 63 files changed, 595 insertions(+), 448 deletions(-) create mode 100644 tests/pos-custom-args/captures/impurefun.scala create mode 100644 tests/pos/impurefun.scala diff --git a/compiler/src/dotty/tools/dotc/ast/untpd.scala b/compiler/src/dotty/tools/dotc/ast/untpd.scala index b9960cbb4652..740b2e3d9ab8 100644 --- a/compiler/src/dotty/tools/dotc/ast/untpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/untpd.scala @@ -70,13 +70,13 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { case class InterpolatedString(id: TermName, segments: List[Tree])(implicit @constructorOnly src: SourceFile) extends TermTree - /** A function type */ + /** A function type or closure */ case class Function(args: List[Tree], body: Tree)(implicit @constructorOnly src: SourceFile) extends Tree { override def isTerm: Boolean = body.isTerm override def isType: Boolean = body.isType } - /** A function type with `implicit`, `erased`, or `given` modifiers */ + /** A function type or closure with `implicit`, `erased`, or `given` modifiers */ class FunctionWithMods(args: List[Tree], body: Tree, val mods: Modifiers)(implicit @constructorOnly src: SourceFile) extends Function(args, body) @@ -217,6 +217,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { case class Transparent()(implicit @constructorOnly src: SourceFile) extends Mod(Flags.Transparent) case class Infix()(implicit @constructorOnly src: SourceFile) extends Mod(Flags.Infix) + + case class Impure()(implicit @constructorOnly src: SourceFile) extends Mod(Flags.Impure) } /** Modifiers and annotations for definitions diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala index 09064314b1bf..4c201d7edf54 100644 --- a/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala +++ b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala @@ -17,9 +17,13 @@ def retainedElems(tree: Tree)(using Context): List[Tree] = tree match case Apply(_, Typed(SeqLiteral(elems, _), _) :: Nil) => elems case _ => Nil +class IllegalCaptureRef(tpe: Type) extends Exception + extension (tree: Tree) - def toCaptureRef(using Context): CaptureRef = tree.tpe.asInstanceOf[CaptureRef] + def toCaptureRef(using Context): CaptureRef = tree.tpe match + case ref: CaptureRef => ref + case tpe => throw IllegalCaptureRef(tpe) def toCaptureSet(using Context): CaptureSet = tree.getAttachment(Captures) match @@ -59,20 +63,6 @@ extension (tp: Type) def isBoxedCapturing(using Context) = !tp.boxedCaptured.isAlwaysEmpty - def canHaveInferredCapture(using Context): Boolean = tp match - case tp: TypeRef if tp.symbol.isClass => - !tp.symbol.isValueClass && tp.symbol != defn.AnyClass - case _: TypeVar | _: TypeParamRef => - false - case tp: TypeProxy => - tp.superType.canHaveInferredCapture - case tp: AndType => - tp.tp1.canHaveInferredCapture && tp.tp2.canHaveInferredCapture - case tp: OrType => - tp.tp1.canHaveInferredCapture || tp.tp2.canHaveInferredCapture - case _ => - false - def stripCapturing(using Context): Type = tp.dealiasKeepAnnots match case CapturingType(parent, _, _) => parent.stripCapturing diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala b/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala index cd8d67399d8d..82e5e6e14a4b 100644 --- a/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala +++ b/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala @@ -56,6 +56,12 @@ sealed abstract class CaptureSet extends Showable: assert(v.isConst) Const(v.elems) + final def isUniversal(using Context) = + elems.exists { + case ref: TermRef => ref.symbol == defn.captureRoot + case _ => false + } + /** Cast to variable. @pre: !isConst */ def asVar: Var = assert(!isConst) diff --git a/compiler/src/dotty/tools/dotc/cc/CapturingType.scala b/compiler/src/dotty/tools/dotc/cc/CapturingType.scala index 2eeb1ff41b72..738e746d0178 100644 --- a/compiler/src/dotty/tools/dotc/cc/CapturingType.scala +++ b/compiler/src/dotty/tools/dotc/cc/CapturingType.scala @@ -12,10 +12,22 @@ object CapturingType: else AnnotatedType(parent, CaptureAnnotation(refs, boxed)) def unapply(tp: AnnotatedType)(using Context): Option[(Type, CaptureSet, Boolean)] = - if ctx.phase == Phases.checkCapturesPhase && tp.annot.symbol == defn.RetainsAnnot then + if ctx.phase == Phases.checkCapturesPhase then EventuallyCapturingType.unapply(tp) + else None + +end CapturingType + +object EventuallyCapturingType: + + def unapply(tp: AnnotatedType)(using Context): Option[(Type, CaptureSet, Boolean)] = + if tp.annot.symbol == defn.RetainsAnnot then tp.annot match case ann: CaptureAnnotation => Some((tp.parent, ann.refs, ann.boxed)) - case ann => Some((tp.parent, ann.tree.toCaptureSet, ann.tree.isBoxedCapturing)) + case ann => + try Some((tp.parent, ann.tree.toCaptureSet, ann.tree.isBoxedCapturing)) + catch case ex: IllegalCaptureRef => None else None -end CapturingType +end EventuallyCapturingType + + diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 5368cc8d38e5..576bdf6d1b95 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -8,7 +8,7 @@ import Flags._, Scopes._, Decorators._, NameOps._, Periods._, NullOpsDecorator._ import unpickleScala2.Scala2Unpickler.ensureConstructor import scala.collection.mutable import collection.mutable -import Denotations.SingleDenotation +import Denotations.{SingleDenotation, staticRef} import util.{SimpleIdentityMap, SourceFile, NoSource} import typer.ImportInfo.RootRef import Comments.CommentsContext @@ -89,7 +89,7 @@ class Definitions { * * FunctionN traits follow this template: * - * trait FunctionN[T0,...T{N-1}, R] extends Object { + * trait FunctionN[-T0,...-T{N-1}, +R] extends Object { * def apply($x0: T0, ..., $x{N_1}: T{N-1}): R * } * @@ -99,46 +99,65 @@ class Definitions { * * ContextFunctionN traits follow this template: * - * trait ContextFunctionN[T0,...,T{N-1}, R] extends Object { + * trait ContextFunctionN[-T0,...,-T{N-1}, +R] extends Object { * def apply(using $x0: T0, ..., $x{N_1}: T{N-1}): R * } * * ErasedFunctionN traits follow this template: * - * trait ErasedFunctionN[T0,...,T{N-1}, R] extends Object { + * trait ErasedFunctionN[-T0,...,-T{N-1}, +R] extends Object { * def apply(erased $x0: T0, ..., $x{N_1}: T{N-1}): R * } * * ErasedContextFunctionN traits follow this template: * - * trait ErasedContextFunctionN[T0,...,T{N-1}, R] extends Object { + * trait ErasedContextFunctionN[-T0,...,-T{N-1}, +R] extends Object { * def apply(using erased $x0: T0, ..., $x{N_1}: T{N-1}): R * } * * ErasedFunctionN and ErasedContextFunctionN erase to Function0. + * + * EffXYZFunctionN afollow this template: + * + * type EffXYZFunctionN[-T0,...,-T{N-1}, +R] = {*} XYZFunctionN[T0,...,T{N-1}, R] */ - def newFunctionNTrait(name: TypeName): ClassSymbol = { + private def newFunctionNType(name: TypeName): Symbol = { + val impure = name.startsWith("Impure") val completer = new LazyType { def complete(denot: SymDenotation)(using Context): Unit = { - val cls = denot.asClass.classSymbol - val decls = newScope val arity = name.functionArity - val paramNamePrefix = tpnme.scala ++ str.NAME_JOIN ++ name ++ str.EXPAND_SEPARATOR - val argParamRefs = List.tabulate(arity) { i => - enterTypeParam(cls, paramNamePrefix ++ "T" ++ (i + 1).toString, Contravariant, decls).typeRef - } - val resParamRef = enterTypeParam(cls, paramNamePrefix ++ "R", Covariant, decls).typeRef - val methodType = MethodType.companion( - isContextual = name.isContextFunction, - isImplicit = false, - isErased = name.isErasedFunction) - decls.enter(newMethod(cls, nme.apply, methodType(argParamRefs, resParamRef), Deferred)) - denot.info = - ClassInfo(ScalaPackageClass.thisType, cls, ObjectType :: Nil, decls) + if impure then + val argParamNames = List.tabulate(arity)(tpnme.syntheticTypeParamName) + val argVariances = List.fill(arity)(Contravariant) + val underlyingName = name.asSimpleName.drop(6) + val underlyingClass = ScalaPackageVal.requiredClass(underlyingName) + denot.info = TypeAlias( + HKTypeLambda(argParamNames :+ "R".toTypeName, argVariances :+ Covariant)( + tl => List.fill(arity + 1)(TypeBounds.empty), + tl => CapturingType(underlyingClass.typeRef.appliedTo(tl.paramRefs), + CaptureSet.universal, boxed = false) + )) + else + val cls = denot.asClass.classSymbol + val decls = newScope + val paramNamePrefix = tpnme.scala ++ str.NAME_JOIN ++ name ++ str.EXPAND_SEPARATOR + val argParamRefs = List.tabulate(arity) { i => + enterTypeParam(cls, paramNamePrefix ++ "T" ++ (i + 1).toString, Contravariant, decls).typeRef + } + val resParamRef = enterTypeParam(cls, paramNamePrefix ++ "R", Covariant, decls).typeRef + val methodType = MethodType.companion( + isContextual = name.isContextFunction, + isImplicit = false, + isErased = name.isErasedFunction) + decls.enter(newMethod(cls, nme.apply, methodType(argParamRefs, resParamRef), Deferred)) + denot.info = + ClassInfo(ScalaPackageClass.thisType, cls, ObjectType :: Nil, decls) } } - val flags = Trait | NoInits - newPermanentClassSymbol(ScalaPackageClass, name, flags, completer) + if impure then + newPermanentSymbol(ScalaPackageClass, name, EmptyFlags, completer) + else + newPermanentClassSymbol(ScalaPackageClass, name, Trait | NoInits, completer) } private def newMethod(cls: ClassSymbol, name: TermName, info: Type, flags: FlagSet = EmptyFlags): TermSymbol = @@ -212,7 +231,7 @@ class Definitions { val cls = ScalaPackageVal.moduleClass.asClass cls.info.decls.openForMutations.useSynthesizer( name => - if (name.isTypeName && name.isSyntheticFunction) newFunctionNTrait(name.asTypeName) + if (name.isTypeName && name.isSyntheticFunction) newFunctionNType(name.asTypeName) else NoSymbol) cls } @@ -1289,39 +1308,55 @@ class Definitions { @tu lazy val TupleType: Array[TypeRef] = mkArityArray("scala.Tuple", MaxTupleArity, 1) + /** Cached function types of arbitary arities. + * Function types are created on demand with newFunctionNTrait, which is + * called from a synthesizer installed in ScalaPackageClass. + */ private class FunType(prefix: String): private var classRefs: Array[TypeRef] = new Array(22) + def apply(n: Int): TypeRef = while n >= classRefs.length do val classRefs1 = new Array[TypeRef](classRefs.length * 2) Array.copy(classRefs, 0, classRefs1, 0, classRefs.length) classRefs = classRefs1 + val funName = s"scala.$prefix$n" if classRefs(n) == null then - classRefs(n) = requiredClassRef(prefix + n.toString) + classRefs(n) = + if prefix.startsWith("Impure") + then staticRef(funName.toTypeName).symbol.typeRef + else requiredClassRef(funName) classRefs(n) - - private val erasedContextFunType = FunType("scala.ErasedContextFunction") - private val contextFunType = FunType("scala.ContextFunction") - private val erasedFunType = FunType("scala.ErasedFunction") - private val funType = FunType("scala.Function") - - def FunctionClass(n: Int, isContextual: Boolean = false, isErased: Boolean = false)(using Context): Symbol = - ( if isContextual && isErased then erasedContextFunType(n) - else if isContextual then contextFunType(n) - else if isErased then erasedFunType(n) - else funType(n) - ).symbol.asClass + end FunType + + private def funTypeIdx(isContextual: Boolean, isErased: Boolean, isImpure: Boolean): Int = + (if isContextual then 1 else 0) + + (if isErased then 2 else 0) + + (if isImpure then 4 else 0) + + private val funTypeArray: IArray[FunType] = + val arr = Array.ofDim[FunType](8) + val choices = List(false, true) + for contxt <- choices; erasd <- choices; impure <- choices do + var str = "Function" + if contxt then str = "Context" + str + if erasd then str = "Erased" + str + if impure then str = "Impure" + str + arr(funTypeIdx(contxt, erasd, impure)) = FunType(str) + IArray.unsafeFromArray(arr) + + def FunctionSymbol(n: Int, isContextual: Boolean = false, isErased: Boolean = false, isImpure: Boolean = false)(using Context): Symbol = + funTypeArray(funTypeIdx(isContextual, isErased, isImpure))(n).symbol @tu lazy val Function0_apply: Symbol = Function0.requiredMethod(nme.apply) - @tu lazy val ContextFunction0_apply: Symbol = ContextFunction0.requiredMethod(nme.apply) - @tu lazy val Function0: Symbol = FunctionClass(0) - @tu lazy val Function1: Symbol = FunctionClass(1) - @tu lazy val Function2: Symbol = FunctionClass(2) - @tu lazy val ContextFunction0: Symbol = FunctionClass(0, isContextual = true) + @tu lazy val Function0: Symbol = FunctionSymbol(0) + @tu lazy val Function1: Symbol = FunctionSymbol(1) + @tu lazy val Function2: Symbol = FunctionSymbol(2) + @tu lazy val ContextFunction0: Symbol = FunctionSymbol(0, isContextual = true) - def FunctionType(n: Int, isContextual: Boolean = false, isErased: Boolean = false)(using Context): TypeRef = - FunctionClass(n, isContextual && !ctx.erasedTypes, isErased).typeRef + def FunctionType(n: Int, isContextual: Boolean = false, isErased: Boolean = false, isImpure: Boolean = false)(using Context): TypeRef = + FunctionSymbol(n, isContextual && !ctx.erasedTypes, isErased, isImpure).typeRef lazy val PolyFunctionClass = requiredClass("scala.PolyFunction") def PolyFunctionType = PolyFunctionClass.typeRef @@ -1363,6 +1398,10 @@ class Definitions { */ def isFunctionClass(cls: Symbol): Boolean = scalaClassName(cls).isFunction + /** Is a function class, or an impure function type alias */ + def isFunctionSymbol(sym: Symbol): Boolean = + sym.isType && (sym.owner eq ScalaPackageClass) && sym.name.isFunction + /** Is a function class where * - FunctionN for N >= 0 and N != XXL */ @@ -1569,7 +1608,7 @@ class Definitions { def isSpecializableFunction(cls: ClassSymbol, paramTypes: List[Type], retType: Type)(using Context): Boolean = paramTypes.length <= 2 - && (cls.derivesFrom(FunctionClass(paramTypes.length)) || isByNameFunctionClass(cls)) + && (cls.derivesFrom(FunctionSymbol(paramTypes.length)) || isByNameFunctionClass(cls)) && isSpecializableFunctionSAM(paramTypes, retType) /** If the Single Abstract Method of a Function class has this type, is it specializable? */ diff --git a/compiler/src/dotty/tools/dotc/core/Flags.scala b/compiler/src/dotty/tools/dotc/core/Flags.scala index cb590e2384a0..f2682621a7bd 100644 --- a/compiler/src/dotty/tools/dotc/core/Flags.scala +++ b/compiler/src/dotty/tools/dotc/core/Flags.scala @@ -314,8 +314,8 @@ object Flags { /** A Scala 2x super accessor / an unpickled Scala 2.x class */ val (SuperParamAliasOrScala2x @ _, SuperParamAlias @ _, Scala2x @ _) = newFlags(26, "", "") - /** A parameter with a default value */ - val (_, HasDefault @ _, _) = newFlags(27, "") + /** A parameter with a default value / an impure untpd.Function type */ + val (_, HasDefault @ _, Impure @ _) = newFlags(27, "", "<{*}>") /** An extension method, or a collective extension instance */ val (Extension @ _, ExtensionMethod @ _, _) = newFlags(28, "") diff --git a/compiler/src/dotty/tools/dotc/core/NameOps.scala b/compiler/src/dotty/tools/dotc/core/NameOps.scala index fb35ac0ac91f..d7c224d90821 100644 --- a/compiler/src/dotty/tools/dotc/core/NameOps.scala +++ b/compiler/src/dotty/tools/dotc/core/NameOps.scala @@ -197,20 +197,25 @@ object NameOps { else collectDigits(acc * 10 + d, idx + 1) collectDigits(0, suffixStart + 8) - /** name[0..suffixStart) == `str` */ - private def isPreceded(str: String, suffixStart: Int) = - str.length == suffixStart && name.firstPart.startsWith(str) + private def isFunctionPrefix(suffixStart: Int, mustHave: String = ""): Boolean = + suffixStart >= 0 + && { + val first = name.firstPart + var found = mustHave.isEmpty + def skip(idx: Int, str: String) = + if first.startsWith(str, idx) then + if str == mustHave then found = true + idx + str.length + else idx + skip(skip(skip(0, "Impure"), "Erased"), "Context") == suffixStart + && found + } /** Same as `funArity`, except that it returns -1 if the prefix * is not one of "", "Context", "Erased", "ErasedContext" */ private def checkedFunArity(suffixStart: Int): Int = - if suffixStart == 0 - || isPreceded("Context", suffixStart) - || isPreceded("Erased", suffixStart) - || isPreceded("ErasedContext", suffixStart) - then funArity(suffixStart) - else -1 + if isFunctionPrefix(suffixStart) then funArity(suffixStart) else -1 /** Is a function name, i.e one of FunctionXXL, FunctionN, ContextFunctionN, ErasedFunctionN, ErasedContextFunctionN for N >= 0 */ @@ -222,19 +227,14 @@ object NameOps { */ def isPlainFunction: Boolean = functionArity >= 0 - /** Is an context function name, i.e one of ContextFunctionN or ErasedContextFunctionN for N >= 0 - */ - def isContextFunction: Boolean = + /** Is a function name that contains `mustHave` as a substring */ + private def isSpecificFunction(mustHave: String): Boolean = val suffixStart = functionSuffixStart - (isPreceded("Context", suffixStart) || isPreceded("ErasedContext", suffixStart)) - && funArity(suffixStart) >= 0 + isFunctionPrefix(suffixStart, mustHave) && funArity(suffixStart) >= 0 - /** Is an erased function name, i.e. one of ErasedFunctionN, ErasedContextFunctionN for N >= 0 - */ - def isErasedFunction: Boolean = - val suffixStart = functionSuffixStart - (isPreceded("Erased", suffixStart) || isPreceded("ErasedContext", suffixStart)) - && funArity(suffixStart) >= 0 + def isContextFunction: Boolean = isSpecificFunction("Context") + def isErasedFunction: Boolean = isSpecificFunction("Erased") + def isImpureFunction: Boolean = isSpecificFunction("Impure") /** Is a synthetic function name, i.e. one of * - FunctionN for N > 22 diff --git a/compiler/src/dotty/tools/dotc/core/StdNames.scala b/compiler/src/dotty/tools/dotc/core/StdNames.scala index 8ab97925ecaa..bcb9f78c7ad6 100644 --- a/compiler/src/dotty/tools/dotc/core/StdNames.scala +++ b/compiler/src/dotty/tools/dotc/core/StdNames.scala @@ -739,6 +739,8 @@ object StdNames { val XOR : N = "^" val ZAND : N = "&&" val ZOR : N = "||" + val PUREARROW: N = "->" + val PURECTXARROW: N = "?->" // unary operators val UNARY_PREFIX: N = "unary_" diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 0d6f137f3dd9..112ae1e76e3c 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -2401,7 +2401,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling case tp1: TypeVar if tp1.isInstantiated => tp1.underlying & tp2 case CapturingType(parent1, refs1, _) => - if subCaptures(tp2.captureSet, refs1, frozenConstraint).isOK then + if subCaptures(tp2.captureSet, refs1, frozen = true).isOK then parent1 & tp2 else tp1.derivedCapturingType(parent1 & tp2, refs1) diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index 8a532c9bf7e5..d0b919bf835b 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -183,7 +183,7 @@ object Types { case _ => false } - /** Is this type a (possibly refined or applied or aliased) type reference + /** Is this type a (possibly refined, applied, aliased or annotated) type reference * to the given type symbol? * @sym The symbol to compare to. It must be a class symbol or abstract type. * It makes no sense for it to be an alias type because isRef would always @@ -204,9 +204,7 @@ object Types { case this1: TypeVar => this1.instanceOpt.isRef(sym, skipRefined) case this1: AnnotatedType => - this1 match - case CapturingType(_, _, _) => false - case _ => this1.parent.isRef(sym, skipRefined) + this1.parent.isRef(sym, skipRefined) case _ => false } diff --git a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala index 2df71a25766a..3beffbbf3b05 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala @@ -1364,47 +1364,56 @@ object Parsers { * | InfixType * | CaptureSet Type * FunType ::= (MonoFunType | PolyFunType) - * MonoFunType ::= FunTypeArgs (‘=>’ | ‘?=>’) Type - * PolyFunType ::= HKTypeParamClause '=>' Type + * MonoFunType ::= FunTypeArgs (‘=>’ | ‘?=>’ | ‘->’ | ‘?->’ ) Type + * PolyFunType ::= HKTypeParamClause ('=>' | ‘->’_) Type * FunTypeArgs ::= InfixType * | `(' [ [ ‘[using]’ ‘['erased'] FunArgType {`,' FunArgType } ] `)' * | '(' [ ‘[using]’ ‘['erased'] TypedFunParam {',' TypedFunParam } ')' * CaptureSet ::= `{` CaptureRef {`,` CaptureRef} `}` * CaptureRef ::= Ident */ - def typ(): Tree = { + def typ(): Tree = val start = in.offset var imods = Modifiers() def functionRest(params: List[Tree]): Tree = val paramSpan = Span(start, in.lastOffset) atSpan(start, in.offset) { - if in.token == TLARROW then + var token = in.token + if in.isIdent(nme.PUREARROW) then + token = ARROW + else if in.isIdent(nme.PURECTXARROW) then + token = CTXARROW + else if token == TLARROW then if !imods.flags.isEmpty || params.isEmpty then syntaxError(em"illegal parameter list for type lambda", start) - in.token = ARROW - else - for case ValDef(_, tpt: ByNameTypeTree, _) <- params do - syntaxError(em"parameter of type lambda may not be call-by-name", tpt.span) - in.nextToken() - return TermLambdaTypeTree(params.asInstanceOf[List[ValDef]], typ()) + token = ARROW + else if ctx.settings.Ycc.value then + // `=>` means impure function under -Ycc whereas `->` is a regular function. + // Without -Ycc they both mean regular function. + imods |= Impure - if in.token == CTXARROW then + if token == CTXARROW then in.nextToken() imods |= Given + else if token == ARROW || token == TLARROW then + in.nextToken() else accept(ARROW) - val t = typ() - if imods.isOneOf(Given | Erased) then + val resultType = typ() + if token == TLARROW then + for case ValDef(_, tpt: ByNameTypeTree, _) <- params do + syntaxError(em"parameter of type lambda may not be call-by-name", tpt.span) + TermLambdaTypeTree(params.asInstanceOf[List[ValDef]], resultType) + else if imods.isOneOf(Given | Erased | Impure) then if imods.is(Given) && params.isEmpty then syntaxError("context function types require at least one parameter", paramSpan) - new FunctionWithMods(params, t, imods) + FunctionWithMods(params, resultType, imods) else if !ctx.settings.YkindProjector.isDefault then - val (newParams :+ newT, tparams) = replaceKindProjectorPlaceholders(params :+ t) - - lambdaAbstract(tparams, Function(newParams, newT)) + val (newParams :+ newResultType, tparams) = replaceKindProjectorPlaceholders(params :+ resultType) + lambdaAbstract(tparams, Function(newParams, newResultType)) else - Function(params, t) + Function(params, resultType) } def funTypeArgsRest(first: Tree, following: () => Tree) = { val buf = new ListBuffer[Tree] += first @@ -1462,7 +1471,7 @@ object Parsers { val tparams = typeParamClause(ParamOwner.TypeParam) if (in.token == TLARROW) atSpan(start, in.skipToken())(LambdaTypeTree(tparams, toplevelTyp())) - else if (in.token == ARROW) { + else if (in.token == ARROW || in.isIdent(nme.PUREARROW)) { val arrowOffset = in.skipToken() val body = toplevelTyp() atSpan(start, arrowOffset) { @@ -1483,16 +1492,18 @@ object Parsers { else if (in.token == INDENT) enclosed(INDENT, typ()) else infixType() - in.token match { + in.token match case ARROW | CTXARROW => functionRest(t :: Nil) case MATCH => matchType(t) case FORSOME => syntaxError(ExistentialTypesNoLongerSupported()); t case _ => - if (imods.is(Erased) && !t.isInstanceOf[FunctionWithMods]) - syntaxError(ErasedTypesCanOnlyBeFunctionTypes(), implicitKwPos(start)) - t - } - } + if isIdent(nme.PUREARROW) || isIdent(nme.PURECTXARROW) then + functionRest(t :: Nil) + else + if (imods.is(Erased) && !t.isInstanceOf[FunctionWithMods]) + syntaxError(ErasedTypesCanOnlyBeFunctionTypes(), implicitKwPos(start)) + t + end typ private def makeKindProjectorTypeDef(name: TypeName): TypeDef = { val isVarianceAnnotated = name.startsWith("+") || name.startsWith("-") @@ -1550,7 +1561,7 @@ object Parsers { def infixTypeRest(t: Tree): Tree = infixOps(t, canStartInfixTypeTokens, refinedTypeFn, Location.ElseWhere, isType = true, - isOperator = !followingIsVararg()) + isOperator = !followingIsVararg() && !isIdent(nme.PUREARROW) && !isIdent(nme.PURECTXARROW)) /** RefinedType ::= WithType {[nl] Refinement} */ diff --git a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala index e2513ec7b9df..6409d37ef735 100644 --- a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala @@ -15,7 +15,7 @@ import util.SourcePosition import scala.util.control.NonFatal import scala.annotation.switch import config.Config -import cc.{CapturingType, CaptureSet} +import cc.{EventuallyCapturingType, CaptureSet} class PlainPrinter(_ctx: Context) extends Printer { @@ -143,6 +143,9 @@ class PlainPrinter(_ctx: Context) extends Printer { + defn.ObjectClass + defn.FromJavaObjectSymbol + def toText(cs: CaptureSet): Text = + "{" ~ Text(cs.elems.toList.map(toTextCaptureRef), ", ") ~ "}" + def toText(tp: Type): Text = controlled { homogenize(tp) match { case tp: TypeType => @@ -197,7 +200,7 @@ class PlainPrinter(_ctx: Context) extends Printer { keywordStr(" match ") ~ "{" ~ casesText ~ "}" ~ (" <: " ~ toText(bound) provided !bound.isAny) }.close - case CapturingType(parent, refs, boxed) => + case EventuallyCapturingType(parent, refs, boxed) => def box = Str("box ") provided boxed if printDebug && !refs.isConst then changePrec(GlobalPrec)(box ~ s"$refs " ~ toText(parent)) @@ -206,11 +209,7 @@ class PlainPrinter(_ctx: Context) extends Printer { else if !refs.isConst && refs.elems.isEmpty then changePrec(GlobalPrec)("?" ~ " " ~ toText(parent)) else if Config.printCaptureSetsAsPrefix then - changePrec(GlobalPrec)( - box ~ "{" - ~ Text(refs.elems.toList.map(toTextCaptureRef), ", ") - ~ "} " - ~ toText(parent)) + changePrec(GlobalPrec)(box ~ toText(refs) ~ " " ~ toText(parent)) else changePrec(InfixPrec)(toText(parent) ~ " retains " ~ box ~ toText(refs.toRetainsTypeArg)) case tp: PreviousErrorType if ctx.settings.XprintTypes.value => diff --git a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala index 2fb1715d4cfc..88fac70581ab 100644 --- a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala @@ -15,7 +15,7 @@ import Annotations.Annotation import Denotations._ import SymDenotations._ import StdNames.{nme, tpnme} -import ast.{Trees, untpd} +import ast.{Trees, tpd, untpd} import typer.{Implicits, Namer, Applications} import typer.ProtoTypes._ import Trees._ @@ -25,10 +25,12 @@ import NameKinds.{WildcardParamName, DefaultGetterName} import util.Chars.isOperatorPart import transform.TypeUtils._ import transform.SymUtils._ +import config.Config import language.implicitConversions import dotty.tools.dotc.util.{NameTransformer, SourcePosition} import dotty.tools.dotc.ast.untpd.{MemberDef, Modifiers, PackageDef, RefTree, Template, TypeDef, ValOrDefDef} +import cc.{EventuallyCapturingType, CaptureSet, toCaptureSet, IllegalCaptureRef} class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { @@ -136,14 +138,14 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { else simpleNameString(tsym) } - private def arrow(isGiven: Boolean): String = - if isGiven then "?=>" else "=>" + private def arrow(isGiven: Boolean, isPure: Boolean): String = + (if isGiven then "?" else "") + (if isPure then "->" else "=>") override def toText(tp: Type): Text = controlled { def toTextTuple(args: List[Type]): Text = "(" ~ argsText(args) ~ ")" - def toTextFunction(args: List[Type], isGiven: Boolean, isErased: Boolean): Text = + def toTextFunction(args: List[Type], isGiven: Boolean, isErased: Boolean, isPure: Boolean): Text = changePrec(GlobalPrec) { val argStr: Text = if args.length == 2 @@ -156,26 +158,26 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { ~ keywordText("erased ").provided(isErased) ~ argsText(args.init) ~ ")" - argStr ~ " " ~ arrow(isGiven) ~ " " ~ argText(args.last) + argStr ~ " " ~ arrow(isGiven, isPure) ~ " " ~ argText(args.last) } - def toTextMethodAsFunction(info: Type): Text = info match + def toTextMethodAsFunction(info: Type, isPure: Boolean): Text = info match case info: MethodType => changePrec(GlobalPrec) { "(" ~ keywordText("erased ").provided(info.isErasedMethod) ~ paramsText(info) ~ ") " - ~ arrow(info.isImplicitMethod) + ~ arrow(info.isImplicitMethod, isPure) ~ " " - ~ toTextMethodAsFunction(info.resultType) + ~ toTextMethodAsFunction(info.resultType, isPure) } case info: PolyType => changePrec(GlobalPrec) { "[" ~ paramsText(info) ~ "] => " - ~ toTextMethodAsFunction(info.resultType) + ~ toTextMethodAsFunction(info.resultType, isPure) } case _ => toText(info) @@ -214,9 +216,11 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { def appliedText(tp: Type): Text = tp match case tp @ AppliedType(tycon, args) => - val cls = tycon.typeSymbol + val tsym = tycon.typeSymbol if tycon.isRepeatedParam then toTextLocal(args.head) ~ "*" - else if defn.isFunctionClass(cls) then toTextFunction(args, cls.name.isContextFunction, cls.name.isErasedFunction) + else if defn.isFunctionSymbol(tsym) then + toTextFunction(args, tsym.name.isContextFunction, tsym.name.isErasedFunction, + isPure = ctx.settings.Ycc.value && !tsym.name.isImpureFunction) else if tp.tupleArity >= 2 && !printDebug then toTextTuple(tp.tupleElementTypes) else if isInfixType(tp) then val l :: r :: Nil = args @@ -243,7 +247,8 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { // don't eta contract if the application would be printed specially toText(tycon) case tp: RefinedType if defn.isFunctionOrPolyType(tp) && !printDebug => - toTextMethodAsFunction(tp.refinedInfo) + toTextMethodAsFunction(tp.refinedInfo, + isPure = ctx.settings.Ycc.value && !tp.typeSymbol.name.isImpureFunction) case tp: TypeRef => if (tp.symbol.isAnonymousClass && !showUniqueIds) toText(tp.info) @@ -259,7 +264,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { case tp: ClassInfo => if tp.cls.derivesFrom(defn.PolyFunctionClass) then tp.member(nme.apply).info match - case info: PolyType => return toTextMethodAsFunction(info) + case info: PolyType => return toTextMethodAsFunction(info, isPure = false) case _ => toTextParents(tp.parents) ~~ "{...}" case JavaArrayType(elemtp) => @@ -527,7 +532,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { changePrec(OrTypePrec) { toText(args(0)) ~ " | " ~ atPrec(OrTypePrec + 1) { toText(args(1)) } } else if (tpt.symbol == defn.andType && args.length == 2) changePrec(AndTypePrec) { toText(args(0)) ~ " & " ~ atPrec(AndTypePrec + 1) { toText(args(1)) } } - else if defn.isFunctionClass(tpt.symbol) + else if defn.isFunctionSymbol(tpt.symbol) && tpt.isInstanceOf[TypeTree] && tree.hasType && !printDebug then changePrec(GlobalPrec) { toText(tree.typeOpt) } else args match @@ -602,7 +607,17 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { case tree: Template => toTextTemplate(tree) case Annotated(arg, annot) => - toTextLocal(arg) ~~ annotText(annot.symbol.enclosingClass, annot) + def captureSet = + annot.asInstanceOf[tpd.Tree].toCaptureSet + def toTextAnnot = + toTextLocal(arg) ~~ annotText(annot.symbol.enclosingClass, annot) + def toTextRetainsAnnot = + try changePrec(GlobalPrec)(toText(captureSet) ~ " " ~ toText(arg)) + catch case ex: IllegalCaptureRef => toTextAnnot + if annot.symbol.maybeOwner == defn.RetainsAnnot + && ctx.settings.Ycc.value && Config.printCaptureSetsAsPrefix && !printDebug + then toTextRetainsAnnot + else toTextAnnot case EmptyTree => "" case TypedSplice(t) => @@ -645,7 +660,12 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { ~ Text(args.map(argToText), ", ") ~ ")" } - argsText ~ " " ~ arrow(isGiven) ~ " " ~ toText(body) + val isPure = + ctx.settings.Ycc.value + && tree.match + case tree: FunctionWithMods => !tree.mods.is(Impure) + case _ => true + argsText ~ " " ~ arrow(isGiven, isPure) ~ " " ~ toText(body) case PolyFunction(targs, body) => val targsText = "[" ~ Text(targs.map((arg: Tree) => toText(arg)), ", ") ~ "]" changePrec(GlobalPrec) { diff --git a/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala index 1f30dd989f3a..e78776cb0158 100644 --- a/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala +++ b/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala @@ -117,82 +117,150 @@ class CheckCaptures extends Recheck: override def transformType(tp: Type, inferred: Boolean, boxed: Boolean)(using Context): Type = - def addInnerVars(tp: Type): Type = tp match - case tp @ AppliedType(tycon, args) => - tp.derivedAppliedType(tycon, args.map(addVars(_, boxed = true))) - case tp @ RefinedType(core, rname, rinfo) => - val rinfo1 = addVars(rinfo) - if defn.isFunctionType(tp) then - rinfo1.toFunctionType(isJava = false, alwaysDependent = true) - else - tp.derivedRefinedType(addInnerVars(core), rname, rinfo1) - case tp: MethodType => - tp.derivedLambdaType( - paramInfos = tp.paramInfos.mapConserve(addVars(_)), - resType = addVars(tp.resType)) - case tp: PolyType => - tp.derivedLambdaType( - resType = addVars(tp.resType)) - case tp: ExprType => - tp.derivedExprType(resType = addVars(tp.resType)) - case _ => - tp - - def addFunctionRefinements(tp: Type): Type = tp match - case tp @ AppliedType(tycon, args) => - if defn.isNonRefinedFunction(tp) then - MethodType.companion( - isContextual = defn.isContextFunctionClass(tycon.classSymbol), - isErased = defn.isErasedFunctionClass(tycon.classSymbol) - )(args.init, addFunctionRefinements(args.last)) - .toFunctionType(isJava = false, alwaysDependent = true) - .showing(i"add function refinement $tp --> $result", capt) - else - tp.derivedAppliedType(tycon, args.map(addFunctionRefinements(_))) - case tp @ RefinedType(core, rname, rinfo) if !defn.isFunctionType(tp) => - tp.derivedRefinedType( - addFunctionRefinements(core), rname, addFunctionRefinements(rinfo)) - case tp: MethodOrPoly => - tp.derivedLambdaType(resType = addFunctionRefinements(tp.resType)) - case tp: ExprType => - tp.derivedExprType(resType = addFunctionRefinements(tp.resType)) - case _ => - tp - - /** Refine a possibly applied class type C where the class has tracked parameters - * x_1: T_1, ..., x_n: T_n to C { val x_1: CV_1 T_1, ..., val x_n: CV_n T_n } - * where CV_1, ..., CV_n are fresh capture sets. + def depFun(tycon: Type, argTypes: List[Type], resType: Type): Type = + MethodType.companion( + isContextual = defn.isContextFunctionClass(tycon.classSymbol), + isErased = defn.isErasedFunctionClass(tycon.classSymbol) + )(argTypes, resType) + .toFunctionType(isJava = false, alwaysDependent = true) + + def box(tp: Type): Type = tp match + case CapturingType(parent, refs, false) => CapturingType(parent, refs, true) + case _ => tp + + /** Perform the following transformation steps everywhere in a type: + * 1. Drop retains annotations + * 2. Turn plain function types into dependent function types, so that + * we can refer to their parameter in capture sets. Currently this is + * only done at the toplevel, i.e. for function types that are not + * themselves argument types of other function types. Without this restriction + * boxmap-paper.scala fails. Need to figure out why. + * 3. Refine other class types C by adding capture set variables to their parameter getters + * (see addCaptureRefinements) + * 4. Add capture set variables to all types that can be tracked + * + * Polytype bounds are only cleaned using step 1, but not otherwise transformed. */ - def addCaptureRefinements(tp: Type): Type = tp.stripped match - case _: TypeRef | _: AppliedType if tp.typeSymbol.isClass => - val cls = tp.typeSymbol.asClass - cls.paramGetters.foldLeft(tp) { (core, getter) => - if getter.termRef.isTracked then - val getterType = tp.memberInfo(getter).strippedDealias - RefinedType(core, getter.name, CapturingType(getterType, CaptureSet.Var(), boxed = false)) - .showing(i"add capture refinement $tp --> $result", capt) - else - core - } - case _ => - tp + def mapInferred = new TypeMap: - def addVars(tp: Type, boxed: Boolean = false): Type = - var tp1 = addInnerVars(tp) - val tp2 = addCaptureRefinements(tp1) - if tp1.canHaveInferredCapture - then CapturingType(tp2, CaptureSet.Var(), boxed) - else tp2 - - if inferred then - val cleanup = new TypeMap: + /** Drop @retains annotations everywhere */ + object cleanup extends TypeMap: def apply(t: Type) = t match case AnnotatedType(parent, annot) if annot.symbol == defn.RetainsAnnot => apply(parent) case _ => mapOver(t) - addVars(addFunctionRefinements(cleanup(tp)), boxed) - .showing(i"reinfer $tp --> $result", capt) + + /** Refine a possibly applied class type C where the class has tracked parameters + * x_1: T_1, ..., x_n: T_n to C { val x_1: CV_1 T_1, ..., val x_n: CV_n T_n } + * where CV_1, ..., CV_n are fresh capture sets. + */ + def addCaptureRefinements(tp: Type): Type = tp match + case _: TypeRef | _: AppliedType if tp.typeParams.isEmpty => + tp.typeSymbol match + case cls: ClassSymbol if !defn.isFunctionClass(cls) => + cls.paramGetters.foldLeft(tp) { (core, getter) => + if getter.termRef.isTracked then + val getterType = tp.memberInfo(getter).strippedDealias + RefinedType(core, getter.name, CapturingType(getterType, CaptureSet.Var(), boxed = false)) + .showing(i"add capture refinement $tp --> $result", capt) + else + core + } + case _ => tp + case _ => tp + + /** Should a capture set variable be added on type `tp`? */ + def canHaveInferredCapture(tp: Type): Boolean = + tp.typeParams.isEmpty && tp.match + case tp: (TypeRef | AppliedType) => + val sym = tp.typeSymbol + if sym.isClass then !sym.isValueClass && sym != defn.AnyClass + else canHaveInferredCapture(tp.superType.dealias) + case tp: (RefinedOrRecType | MatchType) => + canHaveInferredCapture(tp.underlying) + case tp: AndType => + canHaveInferredCapture(tp.tp1) && canHaveInferredCapture(tp.tp2) + case tp: OrType => + canHaveInferredCapture(tp.tp1) || canHaveInferredCapture(tp.tp2) + case _ => + false + + /** Add a capture set variable to `tp` if necessary, or maybe pull out + * an embedded capture set variables from a part of `tp`. + */ + def addVar(tp: Type) = tp match + case tp @ RefinedType(parent @ CapturingType(parent1, refs, boxed), rname, rinfo) => + CapturingType(tp.derivedRefinedType(parent1, rname, rinfo), refs, boxed) + case tp: RecType => + tp.parent match + case CapturingType(parent1, refs, boxed) => + CapturingType(tp.derivedRecType(parent1), refs, boxed) + case _ => + tp // can return `tp` here since unlike RefinedTypes, RecTypes are never created + // by `mapInferred`. Hence if the underlying type admits capture variables + // a variable was already added, and the first case above would apply. + case AndType(CapturingType(parent1, refs1, boxed1), CapturingType(parent2, refs2, boxed2)) => + assert(refs1.asVar.elems.isEmpty) + assert(refs2.asVar.elems.isEmpty) + assert(boxed1 == boxed2) + CapturingType(AndType(parent1, parent2), refs1, boxed1) + case tp @ OrType(CapturingType(parent1, refs1, boxed1), CapturingType(parent2, refs2, boxed2)) => + assert(refs1.asVar.elems.isEmpty) + assert(refs2.asVar.elems.isEmpty) + assert(boxed1 == boxed2) + CapturingType(OrType(parent1, parent2, tp.isSoft), refs1, boxed1) + case tp @ OrType(CapturingType(parent1, refs1, boxed1), tp2) => + CapturingType(OrType(parent1, tp2, tp.isSoft), refs1, boxed1) + case tp @ OrType(tp1, CapturingType(parent2, refs2, boxed2)) => + CapturingType(OrType(tp1, parent2, tp.isSoft), refs2, boxed2) + case _ if canHaveInferredCapture(tp) => + CapturingType(tp, CaptureSet.Var(), boxed = false) + case _ => + tp + + var isTopLevel = true + + def mapNested(ts: List[Type]): List[Type] = + val saved = isTopLevel + isTopLevel = false + try ts.mapConserve(this) finally isTopLevel = saved + + def apply(t: Type) = + val t1 = t match + case AnnotatedType(parent, annot) if annot.symbol == defn.RetainsAnnot => + apply(parent) + case tp @ AppliedType(tycon, args) => + val tycon1 = this(tycon) + if defn.isNonRefinedFunction(tp) then + val args1 = mapNested(args.init) + val res1 = this(args.last) + if isTopLevel then + depFun(tycon1, args1, res1) + .showing(i"add function refinement $tp --> $result", capt) + else + tp.derivedAppliedType(tycon1, args1 :+ res1) + else + tp.derivedAppliedType(tycon1, args.mapConserve(arg => box(this(arg)))) + case tp @ RefinedType(core, rname, rinfo) if defn.isFunctionType(tp) => + apply(rinfo).toFunctionType(isJava = false, alwaysDependent = true) + case tp: MethodType => + tp.derivedLambdaType( + paramInfos = mapNested(tp.paramInfos), + resType = this(tp.resType)) + case tp: TypeLambda => + // Don't recurse into parameter bounds, just cleanup any stray retains annotations + tp.derivedLambdaType( + paramInfos = tp.paramInfos.mapConserve(cleanup(_).bounds), + resType = this(tp.resType)) + case _ => + mapOver(t) + addVar(addCaptureRefinements(t1)) + end mapInferred + + if inferred then + val tp1 = mapInferred(tp) + if boxed then box(tp1) else tp1 else def setBoxed(t: Type) = t match case AnnotatedType(_, annot) if annot.symbol == defn.RetainsAnnot => diff --git a/compiler/src/dotty/tools/dotc/typer/Implicits.scala b/compiler/src/dotty/tools/dotc/typer/Implicits.scala index 4824031f12bc..ad22b4cb1192 100644 --- a/compiler/src/dotty/tools/dotc/typer/Implicits.scala +++ b/compiler/src/dotty/tools/dotc/typer/Implicits.scala @@ -822,7 +822,7 @@ trait Implicits: def isOldStyleFunctionConversion(tpe: Type): Boolean = tpe match { case PolyType(_, resType) => isOldStyleFunctionConversion(resType) - case _ => tpe.derivesFrom(defn.FunctionClass(1)) && !tpe.derivesFrom(defn.ConversionClass) && !tpe.derivesFrom(defn.SubTypeClass) + case _ => tpe.derivesFrom(defn.FunctionSymbol(1)) && !tpe.derivesFrom(defn.ConversionClass) && !tpe.derivesFrom(defn.SubTypeClass) } try diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index bb871654ea5e..950afa76a2e3 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1253,7 +1253,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer val numArgs = args.length val isContextual = funFlags.is(Given) val isErased = funFlags.is(Erased) - val funCls = defn.FunctionClass(numArgs, isContextual, isErased) + val isImpure = funFlags.is(Impure) + val funSym = defn.FunctionSymbol(numArgs, isContextual, isErased, isImpure) /** If `app` is a function type with arguments that are all erased classes, * turn it into an erased function type. @@ -1263,7 +1264,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer if !isErased && numArgs > 0 && args.indexWhere(!_.tpe.isErasedClass) == numArgs => - val tycon1 = TypeTree(defn.FunctionClass(numArgs, isContextual, isErased = true).typeRef) + val tycon1 = TypeTree(defn.FunctionSymbol(numArgs, isContextual, true, isImpure).typeRef) .withSpan(tycon.span) assignType(cpy.AppliedTypeTree(app)(tycon1, args), tycon1, args) case _ => @@ -1290,7 +1291,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer report.error(i"$mt is an illegal function type because it has inter-parameter dependencies", tree.srcPos) val resTpt = TypeTree(mt.nonDependentResultApprox).withSpan(body.span) val typeArgs = appDef.termParamss.head.map(_.tpt) :+ resTpt - val tycon = TypeTree(funCls.typeRef) + val tycon = TypeTree(funSym.typeRef) val core = propagateErased(AppliedTypeTree(tycon, typeArgs)) RefinedTypeTree(core, List(appDef), ctx.owner.asClass) end typedDependent @@ -1301,7 +1302,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer using ctx.fresh.setOwner(newRefinedClassSymbol(tree.span)).setNewScope) case _ => propagateErased( - typed(cpy.AppliedTypeTree(tree)(untpd.TypeTree(funCls.typeRef), args :+ body), pt)) + typed(cpy.AppliedTypeTree(tree)(untpd.TypeTree(funSym.typeRef), args :+ body), pt)) } } diff --git a/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala b/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala index ae3067a894d7..82ed941a8598 100644 --- a/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala +++ b/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala @@ -2698,7 +2698,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler def SomeModule: Symbol = dotc.core.Symbols.defn.SomeClass.companionModule def ProductClass: Symbol = dotc.core.Symbols.defn.ProductClass def FunctionClass(arity: Int, isImplicit: Boolean = false, isErased: Boolean = false): Symbol = - dotc.core.Symbols.defn.FunctionClass(arity, isImplicit, isErased) + dotc.core.Symbols.defn.FunctionSymbol(arity, isImplicit, isErased) def TupleClass(arity: Int): Symbol = dotc.core.Symbols.defn.TupleType(arity).classSymbol.asClass def isTupleClass(sym: Symbol): Boolean = diff --git a/tests/neg-custom-args/capt-wf.scala b/tests/neg-custom-args/capt-wf.scala index dc4d6a0d4bff..3bd80e0d0f68 100644 --- a/tests/neg-custom-args/capt-wf.scala +++ b/tests/neg-custom-args/capt-wf.scala @@ -8,7 +8,7 @@ def test(c: Cap, other: String): Unit = val x2: {other} C = ??? // error: cs is empty val s1 = () => "abc" val x3: {s1} C = ??? // error: cs is empty - val x3a: () => String = s1 + val x3a: () -> String = s1 val s2 = () => if x1 == null then "" else "abc" val x4: {s2} C = ??? // OK val x5: {c, c} C = ??? // error: redundant @@ -26,8 +26,8 @@ def test(c: Cap, other: String): Unit = val y1: {e1} String = ??? // error cs is empty val y2: {o1} String = ??? // error cs is empty - lazy val ev: (Int => Boolean) = (n: Int) => - lazy val od: (Int => Boolean) = (n: Int) => + lazy val ev: (Int -> Boolean) = (n: Int) => + lazy val od: (Int -> Boolean) = (n: Int) => if n == 1 then true else ev(n - 1) if n == 0 then true else od(n - 1) val y3: {ev} String = ??? // error cs is empty diff --git a/tests/neg-custom-args/captures/bounded.scala b/tests/neg-custom-args/captures/bounded.scala index dc2621e95a65..fb6b198fb3a3 100644 --- a/tests/neg-custom-args/captures/bounded.scala +++ b/tests/neg-custom-args/captures/bounded.scala @@ -9,6 +9,6 @@ def test(c: Cap) = def f(x: Int): Int = if c == c then x else 0 val b = new B(f) val r1 = b.elem - val r1c: {c} Int => Int = r1 + val r1c: {c} Int -> Int = r1 val r2 = b.lateElem - val r2c: () => {c} Int => Int = r2 // error \ No newline at end of file + val r2c: () -> {c} Int -> Int = r2 // error \ No newline at end of file diff --git a/tests/neg-custom-args/captures/boxmap.check b/tests/neg-custom-args/captures/boxmap.check index 406077077af5..b3d6605989bf 100644 --- a/tests/neg-custom-args/captures/boxmap.check +++ b/tests/neg-custom-args/captures/boxmap.check @@ -1,7 +1,7 @@ --- [E007] Type Mismatch Error: tests/neg-custom-args/captures/boxmap.scala:14:2 ---------------------------------------- -14 | () => b[Box[B]]((x: A) => box(f(x))) // error +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/boxmap.scala:12:2 ---------------------------------------- +12 | () => b[Box[B]]((x: A) => box(f(x))) // error | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - | Found: {f} () => ? Box[B] - | Required: () => Box[B] + | Found: {f} () -> ? Box[B] + | Required: () -> Box[B] longer explanation available when compiling with `-explain` diff --git a/tests/neg-custom-args/captures/boxmap.scala b/tests/neg-custom-args/captures/boxmap.scala index e335320ef9d4..114aaccb6bb5 100644 --- a/tests/neg-custom-args/captures/boxmap.scala +++ b/tests/neg-custom-args/captures/boxmap.scala @@ -1,14 +1,12 @@ type Top = Any @retains(*) -infix type ==> [A, B] = (A => B) @retains(*) - -type Box[+T <: Top] = ([K <: Top] => (T ==> K) => K) +type Box[+T <: Top] = ([K <: Top] -> (T => K) -> K) def box[T <: Top](x: T): Box[T] = - [K <: Top] => (k: T ==> K) => k(x) + [K <: Top] => (k: T => K) => k(x) -def map[A <: Top, B <: Top](b: Box[A])(f: A ==> B): Box[B] = +def map[A <: Top, B <: Top](b: Box[A])(f: A => B): Box[B] = b[Box[B]]((x: A) => box(f(x))) -def lazymap[A <: Top, B <: Top](b: Box[A])(f: A ==> B): () => Box[B] = +def lazymap[A <: Top, B <: Top](b: Box[A])(f: A => B): () -> Box[B] = () => b[Box[B]]((x: A) => box(f(x))) // error diff --git a/tests/neg-custom-args/captures/byname.scala b/tests/neg-custom-args/captures/byname.scala index ef5876be2c11..feb9461dc4c7 100644 --- a/tests/neg-custom-args/captures/byname.scala +++ b/tests/neg-custom-args/captures/byname.scala @@ -3,7 +3,7 @@ def test(cap1: Cap, cap2: Cap) = def f() = if cap1 == cap1 then g else g def g(x: Int) = if cap2 == cap2 then 1 else x - def h(ff: => {cap2} Int => Int) = ff + def h(ff: => {cap2} Int -> Int) = ff h(f()) // error diff --git a/tests/neg-custom-args/captures/capt-box-env.scala b/tests/neg-custom-args/captures/capt-box-env.scala index e9743054076e..605b446d5262 100644 --- a/tests/neg-custom-args/captures/capt-box-env.scala +++ b/tests/neg-custom-args/captures/capt-box-env.scala @@ -1,5 +1,4 @@ -class C -type Cap = {*} C +@annotation.capability class Cap class Pair[+A, +B](x: A, y: B): def fst: A = x @@ -9,4 +8,4 @@ def test(c: Cap) = def f(x: Cap): Unit = if c == x then () val p = Pair(f, f) val g = () => p.fst == p.snd - val gc: () => Boolean = g // error + val gc: () -> Boolean = g // error diff --git a/tests/neg-custom-args/captures/capt-box.scala b/tests/neg-custom-args/captures/capt-box.scala index 317fc064ec0b..634470704fc5 100644 --- a/tests/neg-custom-args/captures/capt-box.scala +++ b/tests/neg-custom-args/captures/capt-box.scala @@ -1,6 +1,4 @@ -//import scala.retains -class C -type Cap = {*} C +@annotation.capability class Cap def test(x: Cap) = @@ -10,4 +8,4 @@ def test(x: Cap) = val x2 = identity(x1) - val x3: Cap => Unit = x2 // error \ No newline at end of file + val x3: Cap -> Unit = x2 // error \ No newline at end of file diff --git a/tests/neg-custom-args/captures/capt-env.scala b/tests/neg-custom-args/captures/capt-env.scala index 84b4b57a7930..52fa4abfdaa8 100644 --- a/tests/neg-custom-args/captures/capt-env.scala +++ b/tests/neg-custom-args/captures/capt-env.scala @@ -9,5 +9,5 @@ def test(c: Cap) = def f(x: Cap): Unit = if c == x then () val p = Pair(f, f) val g = () => p.fst == p.snd - val gc: () => Boolean = g // error + val gc: () -> Boolean = g // error diff --git a/tests/neg-custom-args/captures/capt1.check b/tests/neg-custom-args/captures/capt1.check index ce7c4833bf9c..0b99f1bac09e 100644 --- a/tests/neg-custom-args/captures/capt1.check +++ b/tests/neg-custom-args/captures/capt1.check @@ -1,21 +1,21 @@ -- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:3:2 ------------------------------------------ 3 | () => if x == null then y else y // error | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - | Found: {x} () => ? C - | Required: () => C + | Found: {x} () -> ? C + | Required: () -> C longer explanation available when compiling with `-explain` -- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:6:2 ------------------------------------------ 6 | () => if x == null then y else y // error | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - | Found: {x} () => ? C + | Found: {x} () -> ? C | Required: Matchable longer explanation available when compiling with `-explain` -- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:13:2 ----------------------------------------- 13 | def f(y: Int) = if x == null then y else y // error | ^ - | Found: {x} Int => Int + | Found: {x} Int -> Int | Required: Matchable 14 | f @@ -38,9 +38,9 @@ longer explanation available when compiling with `-explain` longer explanation available when compiling with `-explain` -- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:31:24 ---------------------------------------- -31 | val z2 = h[() => Cap](() => x)(() => C()) // error +31 | val z2 = h[() -> Cap](() => x)(() => C()) // error | ^^^^^^^ - | Found: {x} () => ? Cap - | Required: () => Cap + | Found: {x} () -> Cap + | Required: () -> Cap longer explanation available when compiling with `-explain` diff --git a/tests/neg-custom-args/captures/capt1.scala b/tests/neg-custom-args/captures/capt1.scala index 4da49c5f4f1e..e230defda170 100644 --- a/tests/neg-custom-args/captures/capt1.scala +++ b/tests/neg-custom-args/captures/capt1.scala @@ -1,5 +1,5 @@ class C -def f(x: C @retains(*), y: C): () => C = +def f(x: C @retains(*), y: C): () -> C = () => if x == null then y else y // error def g(x: C @retains(*), y: C): Matchable = @@ -28,7 +28,7 @@ def h4(x: Cap, y: Int): A = def foo() = val x: C @retains(*) = ??? def h[X](a: X)(b: X) = a - val z2 = h[() => Cap](() => x)(() => C()) // error - val z3 = h[(() => Cap) @retains(x)](() => x)(() => C()) // ok - val z4 = h[(() => Cap) @retains(x)](() => x)(() => C()) // what was inferred for z3 + val z2 = h[() -> Cap](() => x)(() => C()) // error + val z3 = h[(() -> Cap) @retains(x)](() => x)(() => C()) // ok + val z4 = h[(() -> Cap) @retains(x)](() => x)(() => C()) // what was inferred for z3 diff --git a/tests/neg-custom-args/captures/capt2.scala b/tests/neg-custom-args/captures/capt2.scala index 1eee53463f6d..8b08832dfdb9 100644 --- a/tests/neg-custom-args/captures/capt2.scala +++ b/tests/neg-custom-args/captures/capt2.scala @@ -2,8 +2,8 @@ class C type Cap = {*} C -def f1(c: Cap): (() => {c} C) = () => c // error, but would be OK under capture abbreciations for funciton types -def f2(c: Cap): ({c} () => C) = () => c // error +def f1(c: Cap): (() -> {c} C) = () => c // error, but would be OK under capture abbreciations for funciton types +def f2(c: Cap): ({c} () -> C) = () => c // error -def h5(x: Cap): () => C = +def h5(x: Cap): () -> C = f1(x) // error diff --git a/tests/neg-custom-args/captures/capt3.scala b/tests/neg-custom-args/captures/capt3.scala index 80b937276f73..6e9ea02fe8e3 100644 --- a/tests/neg-custom-args/captures/capt3.scala +++ b/tests/neg-custom-args/captures/capt3.scala @@ -5,22 +5,22 @@ def test1() = val x: Cap = C() val y = () => { x; () } val z = y - z: (() => Unit) // error + z: (() -> Unit) // error def test2() = val x: Cap = C() def y = () => { x; () } def z = y - z: (() => Unit) // error + z: (() -> Unit) // error def test3() = val x: Cap = C() def y = () => { x; () } val z = y - z: (() => Unit) // error + z: (() -> Unit) // error def test4() = val x: Cap = C() val y = () => { x; () } def z = y - z: (() => Unit) // error + z: (() -> Unit) // error diff --git a/tests/neg-custom-args/captures/io.scala b/tests/neg-custom-args/captures/io.scala index 17c22a2111e4..c0cb11686b32 100644 --- a/tests/neg-custom-args/captures/io.scala +++ b/tests/neg-custom-args/captures/io.scala @@ -4,13 +4,13 @@ sealed trait IO: def test1 = val IO : IO @retains(*) = new IO {} def foo = {IO; IO.puts("hello") } - val x : () => Unit = () => foo // error: Found: (() => Unit) retains IO; Required: () => Unit + val x : () -> Unit = () => foo // error: Found: (() -> Unit) retains IO; Required: () -> Unit def test2 = val IO : IO @retains(*) = new IO {} def puts(msg: Any, io: IO @retains(*)) = println(msg) def foo() = puts("hello", IO) - val x : () => Unit = () => foo() // error: Found: (() => Unit) retains IO; Required: () => Unit + val x : () -> Unit = () => foo() // error: Found: (() -> Unit) retains IO; Required: () -> Unit type Capability[T] = T @retains(*) @@ -18,5 +18,5 @@ def test3 = val IO : Capability[IO] = new IO {} def puts(msg: Any, io: Capability[IO]) = println(msg) def foo() = puts("hello", IO) - val x : () => Unit = () => foo() // error: Found: (() => Unit) retains IO; Required: () => Unit + val x : () -> Unit = () => foo() // error: Found: (() -> Unit) retains IO; Required: () -> Unit diff --git a/tests/neg-custom-args/captures/lazylist.check b/tests/neg-custom-args/captures/lazylist.check index 3a80de9bdf16..0de190df8f11 100644 --- a/tests/neg-custom-args/captures/lazylist.check +++ b/tests/neg-custom-args/captures/lazylist.check @@ -8,7 +8,7 @@ longer explanation available when compiling with `-explain` -- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylist.scala:35:29 ------------------------------------- 35 | val ref1c: LazyList[Int] = ref1 // error | ^^^^ - | Found: (ref1 : {cap1} lazylists.LazyCons[Int]{xs: {cap1} () => {*} lazylists.LazyList[Int]}) + | Found: (ref1 : {cap1} lazylists.LazyCons[Int]{xs: {cap1} () -> {*} lazylists.LazyList[Int]}) | Required: lazylists.LazyList[Int] longer explanation available when compiling with `-explain` diff --git a/tests/neg-custom-args/captures/lazylist.scala b/tests/neg-custom-args/captures/lazylist.scala index f7be43e8dc27..56bfc3ea6da2 100644 --- a/tests/neg-custom-args/captures/lazylist.scala +++ b/tests/neg-custom-args/captures/lazylist.scala @@ -7,11 +7,11 @@ abstract class LazyList[+T]: def head: T def tail: LazyList[T] - def map[U](f: {*} T => U): {f, this} LazyList[U] = + def map[U](f: T => U): {f, this} LazyList[U] = if isEmpty then LazyNil else LazyCons(f(head), () => tail.map(f)) -class LazyCons[+T](val x: T, val xs: {*} () => {*} LazyList[T]) extends LazyList[T]: +class LazyCons[+T](val x: T, val xs: () => {*} LazyList[T]) extends LazyList[T]: def isEmpty = false def head = x def tail = xs() // error: cannot have an inferred type @@ -21,7 +21,7 @@ object LazyNil extends LazyList[Nothing]: def head = ??? def tail: {*} LazyList[Nothing] = ??? // error overriding -def map[A, B](xs: {*} LazyList[A], f: {*} A => B): {f, xs} LazyList[B] = +def map[A, B](xs: {*} LazyList[A], f: A => B): {f, xs} LazyList[B] = xs.map(f) class CC diff --git a/tests/neg-custom-args/captures/lazylists1.scala b/tests/neg-custom-args/captures/lazylists1.scala index 02c7cb4ff3e5..4091ee2c62ae 100644 --- a/tests/neg-custom-args/captures/lazylists1.scala +++ b/tests/neg-custom-args/captures/lazylists1.scala @@ -14,7 +14,7 @@ object LazyNil extends LazyList[Nothing]: def tail = ??? extension [A](xs: {*} LazyList[A]) - def map[B](f: {*} A => B): {xs, f} LazyList[B] = + def map[B](f: A => B): {xs, f} LazyList[B] = final class Mapped extends LazyList[B]: this: ({xs, f} Mapped) => diff --git a/tests/neg-custom-args/captures/lazylists2.scala b/tests/neg-custom-args/captures/lazylists2.scala index c31a1ae5d04f..b9ebb0a7a9f0 100644 --- a/tests/neg-custom-args/captures/lazylists2.scala +++ b/tests/neg-custom-args/captures/lazylists2.scala @@ -14,7 +14,7 @@ object LazyNil extends LazyList[Nothing]: def tail = ??? extension [A](xs: {*} LazyList[A]) - def map[B](f: {*} A => B): {f} LazyList[B] = + def map[B](f: A => B): {f} LazyList[B] = final class Mapped extends LazyList[B]: // error this: ({xs, f} Mapped) => @@ -23,7 +23,7 @@ extension [A](xs: {*} LazyList[A]) def tail: {this} LazyList[B] = xs.tail.map(f) new Mapped - def map2[B](f: {*} A => B): {xs} LazyList[B] = + def map2[B](f: A => B): {xs} LazyList[B] = final class Mapped extends LazyList[B]: // error this: ({xs, f} Mapped) => @@ -32,7 +32,7 @@ extension [A](xs: {*} LazyList[A]) def tail: {this} LazyList[B] = xs.tail.map(f) new Mapped - def map3[B](f: {*} A => B): {xs} LazyList[B] = + def map3[B](f: A => B): {xs} LazyList[B] = final class Mapped extends LazyList[B]: this: ({xs} Mapped) => @@ -41,7 +41,7 @@ extension [A](xs: {*} LazyList[A]) def tail: {this} LazyList[B] = xs.tail.map(f) // error new Mapped - def map4[B](f: {*} A => B): {xs} LazyList[B] = + def map4[B](f: A => B): {xs} LazyList[B] = final class Mapped extends LazyList[B]: this: ({xs, f} Mapped) => @@ -50,7 +50,7 @@ extension [A](xs: {*} LazyList[A]) def tail: {xs, f} LazyList[B] = xs.tail.map(f) // error new Mapped - def map5[B](f: {*} A => B): LazyList[B] = + def map5[B](f: A => B): LazyList[B] = class Mapped extends LazyList[B]: this: ({xs, f} Mapped) => diff --git a/tests/neg-custom-args/captures/lazyref.check b/tests/neg-custom-args/captures/lazyref.check index 2affed020dec..e4e06f8c52cb 100644 --- a/tests/neg-custom-args/captures/lazyref.check +++ b/tests/neg-custom-args/captures/lazyref.check @@ -1,28 +1,28 @@ -- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazyref.scala:19:28 -------------------------------------- 19 | val ref1c: LazyRef[Int] = ref1 // error | ^^^^ - | Found: (ref1 : {cap1} LazyRef[Int]{elem: {cap1} () => Int}) + | Found: (ref1 : {cap1} LazyRef[Int]{elem: {cap1} () -> Int}) | Required: LazyRef[Int] longer explanation available when compiling with `-explain` -- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazyref.scala:21:35 -------------------------------------- 21 | val ref2c: {cap2} LazyRef[Int] = ref2 // error | ^^^^ - | Found: (ref2 : {cap2, ref1} LazyRef[Int]{elem: {*} () => Int}) + | Found: (ref2 : {cap2, ref1} LazyRef[Int]{elem: {*} () -> Int}) | Required: {cap2} LazyRef[Int] longer explanation available when compiling with `-explain` -- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazyref.scala:23:35 -------------------------------------- 23 | val ref3c: {ref1} LazyRef[Int] = ref3 // error | ^^^^ - | Found: (ref3 : {cap2, ref1} LazyRef[Int]{elem: {*} () => Int}) + | Found: (ref3 : {cap2, ref1} LazyRef[Int]{elem: {*} () -> Int}) | Required: {ref1} LazyRef[Int] longer explanation available when compiling with `-explain` -- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazyref.scala:25:35 -------------------------------------- 25 | val ref4c: {cap1} LazyRef[Int] = ref4 // error | ^^^^ - | Found: (ref4 : {cap2, cap1} LazyRef[Int]{elem: {*} () => Int}) + | Found: (ref4 : {cap2, cap1} LazyRef[Int]{elem: {*} () -> Int}) | Required: {cap1} LazyRef[Int] longer explanation available when compiling with `-explain` diff --git a/tests/neg-custom-args/captures/lazyref.scala b/tests/neg-custom-args/captures/lazyref.scala index 1002f9685675..2b278fd51a43 100644 --- a/tests/neg-custom-args/captures/lazyref.scala +++ b/tests/neg-custom-args/captures/lazyref.scala @@ -1,15 +1,15 @@ class CC type Cap = {*} CC -class LazyRef[T](val elem: {*} () => T): +class LazyRef[T](val elem: () => T): val get = elem - def map[U](f: {*} T => U): {f, this} LazyRef[U] = + def map[U](f: T => U): {f, this} LazyRef[U] = new LazyRef(() => f(elem())) -def map[A, B](ref: {*} LazyRef[A], f: {*} A => B): {f, ref} LazyRef[B] = +def map[A, B](ref: {*} LazyRef[A], f: A => B): {f, ref} LazyRef[B] = new LazyRef(() => f(ref.elem())) -def mapc[A, B]: (ref: {*} LazyRef[A], f: {*} A => B) => {f, ref} LazyRef[B] = +def mapc[A, B]: (ref: {*} LazyRef[A], f: A => B) -> {f, ref} LazyRef[B] = (ref1, f1) => map[A, B](ref1, f1) def test(cap1: Cap, cap2: Cap) = diff --git a/tests/neg-custom-args/captures/try.check b/tests/neg-custom-args/captures/try.check index bd95835c6525..a2fe96016b80 100644 --- a/tests/neg-custom-args/captures/try.check +++ b/tests/neg-custom-args/captures/try.check @@ -1,8 +1,8 @@ -- [E007] Type Mismatch Error: tests/neg-custom-args/captures/try.scala:28:43 ------------------------------------------ -28 | val b = handle[Exception, () => Nothing] { // error +28 | val b = handle[Exception, () -> Nothing] { // error | ^ - | Found: ? (x: CanThrow[Exception]) => {x} () => ? Nothing - | Required: CanThrow[Exception] => () => Nothing + | Found: ? (x: CanThrow[Exception]) -> {x} () -> ? Nothing + | Required: CanThrow[Exception] => () -> Nothing 29 | (x: CanThrow[Exception]) => () => raise(new Exception)(using x) 30 | } { @@ -14,12 +14,12 @@ longer explanation available when compiling with `-explain` -- Error: tests/neg-custom-args/captures/try.scala:34:11 --------------------------------------------------------------- 34 | val xx = handle { // error | ^^^^^^ - | inferred type argument {*} () => Int is not allowed to capture the universal capability (* : Any) + | inferred type argument {*} () -> Int is not allowed to capture the universal capability (* : Any) | - | The inferred arguments are: [? Exception, {*} () => Int] + | The inferred arguments are: [? Exception, {*} () -> Int] -- Error: tests/neg-custom-args/captures/try.scala:46:13 --------------------------------------------------------------- 46 |val global = handle { // error | ^^^^^^ - | inferred type argument {*} () => Int is not allowed to capture the universal capability (* : Any) + | inferred type argument {*} () -> Int is not allowed to capture the universal capability (* : Any) | - | The inferred arguments are: [? Exception, {*} () => Int] + | The inferred arguments are: [? Exception, {*} () -> Int] diff --git a/tests/neg-custom-args/captures/try.scala b/tests/neg-custom-args/captures/try.scala index 804a16192be0..b128f82a2a3c 100644 --- a/tests/neg-custom-args/captures/try.scala +++ b/tests/neg-custom-args/captures/try.scala @@ -25,7 +25,7 @@ def test = (ex: Exception) => ??? } - val b = handle[Exception, () => Nothing] { // error + val b = handle[Exception, () -> Nothing] { // error (x: CanThrow[Exception]) => () => raise(new Exception)(using x) } { (ex: Exception) => ??? diff --git a/tests/neg-custom-args/captures/vars.check b/tests/neg-custom-args/captures/vars.check index 4eab5b6b2b3a..0df38b918862 100644 --- a/tests/neg-custom-args/captures/vars.check +++ b/tests/neg-custom-args/captures/vars.check @@ -1,21 +1,21 @@ -- [E007] Type Mismatch Error: tests/neg-custom-args/captures/vars.scala:11:24 ----------------------------------------- -11 | val z2c: () => Unit = z2 // error +11 | val z2c: () -> Unit = z2 // error | ^^ - | Found: (z2 : {x, cap1} () => Unit) - | Required: () => Unit + | Found: (z2 : {x, cap1} () -> Unit) + | Required: () -> Unit longer explanation available when compiling with `-explain` --- Error: tests/neg-custom-args/captures/vars.scala:13:10 -------------------------------------------------------------- -13 | var a: {*} String => String = f // error - | ^^^^^^^^^^^^^^^^^^^ - | type of mutable variable box {*} String => String is not allowed to capture the universal capability (* : Any) +-- Error: tests/neg-custom-args/captures/vars.scala:13:16 -------------------------------------------------------------- +13 | var a: String => String = f // error + | ^^^^^^^^^^^^^^^^ + | type of mutable variable String => String is not allowed to capture the universal capability (* : Any) -- Error: tests/neg-custom-args/captures/vars.scala:14:9 --------------------------------------------------------------- -14 | var b: List[{*} String => String] = Nil // error - | ^^^^^^^^^^^^^^^^^^^^^^^^^^ - |type of mutable variable List[box {*} String => String] is not allowed to capture the universal capability (* : Any) +14 | var b: List[String => String] = Nil // error + | ^^^^^^^^^^^^^^^^^^^^^^ + | type of mutable variable List[String => String] is not allowed to capture the universal capability (* : Any) -- Error: tests/neg-custom-args/captures/vars.scala:29:2 --------------------------------------------------------------- 29 | local { cap3 => // error | ^^^^^ - |inferred type argument {*} (x$0: ? String) => ? String is not allowed to capture the universal capability (* : Any) + |inferred type argument {*} (x$0: ? String) -> ? String is not allowed to capture the universal capability (* : Any) | - |The inferred arguments are: [{*} (x$0: ? String) => ? String] + |The inferred arguments are: [{*} (x$0: ? String) -> ? String] diff --git a/tests/neg-custom-args/captures/vars.scala b/tests/neg-custom-args/captures/vars.scala index 4a58f79932b3..e85bcbe2db04 100644 --- a/tests/neg-custom-args/captures/vars.scala +++ b/tests/neg-custom-args/captures/vars.scala @@ -6,12 +6,12 @@ def test(cap1: Cap, cap2: Cap) = var x = f val y = x val z = () => if x("") == "" then "a" else "b" - val zc: {cap1} () => String = z + val zc: {cap1} () -> String = z val z2 = () => { x = identity } - val z2c: () => Unit = z2 // error + val z2c: () -> Unit = z2 // error - var a: {*} String => String = f // error - var b: List[{*} String => String] = Nil // error + var a: String => String = f // error + var b: List[String => String] = Nil // error def scope = val cap3: Cap = CC() @@ -22,9 +22,9 @@ def test(cap1: Cap, cap2: Cap) = g val s = scope - val sc: {*} String => String = scope + val sc: String => String = scope - def local[T](op: Cap => T): T = op(CC()) + def local[T](op: Cap -> T): T = op(CC()) local { cap3 => // error def g(x: String): String = if cap3 == cap3 then "" else "a" @@ -32,7 +32,7 @@ def test(cap1: Cap, cap2: Cap) = } class Ref: - var elem: {cap1} String => String = null + var elem: {cap1} String -> String = null val r = Ref() r.elem = f diff --git a/tests/pos-custom-args/captures/bounded.scala b/tests/pos-custom-args/captures/bounded.scala index fad0b50c2137..85c1a67387b5 100644 --- a/tests/pos-custom-args/captures/bounded.scala +++ b/tests/pos-custom-args/captures/bounded.scala @@ -9,6 +9,6 @@ def test(c: Cap) = def f(x: Int): Int = if c == c then x else 0 val b = new B(f) val r1 = b.elem - val r1c: {c} Int => Int = r1 + val r1c: {c} Int -> Int = r1 val r2 = b.lateElem - val r2c: {c} () => {c} Int => Int = r2 \ No newline at end of file + val r2c: {c} () -> {c} Int -> Int = r2 \ No newline at end of file diff --git a/tests/pos-custom-args/captures/boxmap-paper.scala b/tests/pos-custom-args/captures/boxmap-paper.scala index ed8c648526d1..aff4c38e1b9d 100644 --- a/tests/pos-custom-args/captures/boxmap-paper.scala +++ b/tests/pos-custom-args/captures/boxmap-paper.scala @@ -1,19 +1,18 @@ -infix type ==> [A, B] = {*} (A => B) -type Cell[+T] = [K] => (T ==> K) => K +type Cell[+T] = [K] -> (T => K) -> K def cell[T](x: T): Cell[T] = - [K] => (k: T ==> K) => k(x) + [K] => (k: T => K) => k(x) def get[T](c: Cell[T]): T = c[T](identity) -def map[A, B](c: Cell[A])(f: A ==> B): Cell[B] +def map[A, B](c: Cell[A])(f: A => B): Cell[B] = c[Cell[B]]((x: A) => cell(f(x))) -def pureMap[A, B](c: Cell[A])(f: A => B): Cell[B] +def pureMap[A, B](c: Cell[A])(f: A -> B): Cell[B] = c[Cell[B]]((x: A) => cell(f(x))) -def lazyMap[A, B](c: Cell[A])(f: A ==> B): {f} () => Cell[B] +def lazyMap[A, B](c: Cell[A])(f: A => B): {f} () -> Cell[B] = () => c[Cell[B]]((x: A) => cell(f(x))) trait IO: @@ -21,17 +20,17 @@ trait IO: def test(io: {*} IO) = - val loggedOne: {io} () => Int = () => { io.print("1"); 1 } + val loggedOne: {io} () -> Int = () => { io.print("1"); 1 } - val c: Cell[{io} () => Int] - = cell[{io} () => Int](loggedOne) + val c: Cell[{io} () -> Int] + = cell[{io} () -> Int](loggedOne) - val g = (f: {io} () => Int) => + val g = (f: {io} () -> Int) => val x = f(); io.print(" + ") val y = f(); io.print(s" = ${x + y}") - val r = lazyMap[{io} () => Int, Unit](c)(f => g(f)) - val r2 = lazyMap[{io} () => Int, Unit](c)(g) + val r = lazyMap[{io} () -> Int, Unit](c)(f => g(f)) + val r2 = lazyMap[{io} () -> Int, Unit](c)(g) val r3 = lazyMap(c)(g) val _ = r() val _ = r2() diff --git a/tests/pos-custom-args/captures/boxmap.scala b/tests/pos-custom-args/captures/boxmap.scala index a0dcade2b179..003e46804a9d 100644 --- a/tests/pos-custom-args/captures/boxmap.scala +++ b/tests/pos-custom-args/captures/boxmap.scala @@ -1,20 +1,18 @@ type Top = Any @retains(*) -infix type ==> [A, B] = (A => B) @retains(*) - -type Box[+T <: Top] = ([K <: Top] => (T ==> K) => K) +type Box[+T <: Top] = ([K <: Top] -> (T => K) -> K) def box[T <: Top](x: T): Box[T] = - [K <: Top] => (k: T ==> K) => k(x) + [K <: Top] => (k: T => K) => k(x) -def map[A <: Top, B <: Top](b: Box[A])(f: A ==> B): Box[B] = +def map[A <: Top, B <: Top](b: Box[A])(f: A => B): Box[B] = b[Box[B]]((x: A) => box(f(x))) -def lazymap[A <: Top, B <: Top](b: Box[A])(f: A ==> B): (() => Box[B]) @retains(f) = +def lazymap[A <: Top, B <: Top](b: Box[A])(f: A => B): (() -> Box[B]) @retains(f) = () => b[Box[B]]((x: A) => box(f(x))) def test[A <: Top, B <: Top] = - def lazymap[A <: Top, B <: Top](b: Box[A])(f: A ==> B) = + def lazymap[A <: Top, B <: Top](b: Box[A])(f: A => B) = () => b[Box[B]]((x: A) => box(f(x))) - val x: (b: Box[A]) => (f: A ==> B) => (() => Box[B]) @retains(f) = lazymap[A, B] + val x: (b: Box[A]) -> (f: A => B) -> (() -> Box[B]) @retains(f) = lazymap[A, B] () diff --git a/tests/pos-custom-args/captures/capt-capability.scala b/tests/pos-custom-args/captures/capt-capability.scala index 41da15d288f1..9990542a199d 100644 --- a/tests/pos-custom-args/captures/capt-capability.scala +++ b/tests/pos-custom-args/captures/capt-capability.scala @@ -1,15 +1,15 @@ import annotation.capability @capability class Cap -def f1(c: Cap): {c} () => c.type = () => c // ok +def f1(c: Cap): {c} () -> c.type = () => c // ok def f2: Int = - val g: {*} Boolean => Int = ??? + val g: Boolean => Int = ??? val x = g(true) x def f3: Int = - def g: {*} Boolean => Int = ??? + def g: Boolean => Int = ??? def h = g val x = g.apply(true) x @@ -17,10 +17,10 @@ def f3: Int = def foo() = val x: Cap = ??? val y: Cap = x - val x2: {x} () => Cap = ??? - val y2: {x} () => Cap = x2 + val x2: {x} () -> Cap = ??? + val y2: {x} () -> Cap = x2 - val z1: {*} () => Cap = f1(x) + val z1: () => Cap = f1(x) def h[X](a: X)(b: X) = a val z2 = diff --git a/tests/pos-custom-args/captures/capt-depfun.scala b/tests/pos-custom-args/captures/capt-depfun.scala index 6b99eff32692..072eaefd3e78 100644 --- a/tests/pos-custom-args/captures/capt-depfun.scala +++ b/tests/pos-custom-args/captures/capt-depfun.scala @@ -1,18 +1,18 @@ class C type Cap = C @retains(*) -type T = (x: Cap) => String @retains(x) +type T = (x: Cap) -> String @retains(x) -val aa: ((x: Cap) => String @retains(x)) = (x: Cap) => "" +val aa: ((x: Cap) -> String @retains(x)) = (x: Cap) => "" def f(y: Cap, z: Cap): String @retains(*) = - val a: ((x: Cap) => String @retains(x)) = (x: Cap) => "" + val a: ((x: Cap) -> String @retains(x)) = (x: Cap) => "" val b = a(y) val c: String @retains(y) = b def g(): C @retains(y, z) = ??? val d = a(g()) - val ac: ((x: Cap) => String @retains(x) => String @retains(x)) = ??? - val bc: (({y} String) => {y} String) = ac(y) - val dc: (String => {y, z} String) = ac(g()) + val ac: ((x: Cap) -> String @retains(x) -> String @retains(x)) = ??? + val bc: (({y} String) -> {y} String) = ac(y) + val dc: (String -> {y, z} String) = ac(g()) c diff --git a/tests/pos-custom-args/captures/capt-depfun2.scala b/tests/pos-custom-args/captures/capt-depfun2.scala index 17f98b4a1554..98ee9dbfdc6b 100644 --- a/tests/pos-custom-args/captures/capt-depfun2.scala +++ b/tests/pos-custom-args/captures/capt-depfun2.scala @@ -3,6 +3,6 @@ type Cap = C @retains(*) def f(y: Cap, z: Cap) = def g(): C @retains(y, z) = ??? - val ac: ((x: Cap) => Array[String @retains(x)]) = ??? + val ac: ((x: Cap) -> Array[String @retains(x)]) = ??? val dc: Array[? >: String <: {y, z} String] = ac(g()) // needs to be inferred val ec = ac(y) diff --git a/tests/pos-custom-args/captures/capt-test.scala b/tests/pos-custom-args/captures/capt-test.scala index f40bd2ff1746..6ee0d2a4d9f4 100644 --- a/tests/pos-custom-args/captures/capt-test.scala +++ b/tests/pos-custom-args/captures/capt-test.scala @@ -2,7 +2,7 @@ abstract class LIST[+T]: def isEmpty: Boolean def head: T def tail: LIST[T] - def map[U](f: {*} T => U): LIST[U] = + def map[U](f: T => U): LIST[U] = if isEmpty then NIL else CONS(f(head), tail.map(f)) @@ -15,7 +15,7 @@ object NIL extends LIST[Nothing]: def head = ??? def tail = ??? -def map[A, B](f: {*} A => B)(xs: LIST[A]): LIST[B] = +def map[A, B](f: A => B)(xs: LIST[A]): LIST[B] = xs.map(f) class C @@ -29,7 +29,7 @@ def test(c: Cap, d: Cap) = val zs = val z = g CONS(z, ys) - val zsc: LIST[{d, y} Cap => Unit] = zs + val zsc: LIST[{d, y} Cap -> Unit] = zs val a4 = zs.map(identity) - val a4c: LIST[{d, y} Cap => Unit] = a4 + val a4c: LIST[{d, y} Cap -> Unit] = a4 diff --git a/tests/pos-custom-args/captures/capt1.scala b/tests/pos-custom-args/captures/capt1.scala index 14c0855544d4..e8e217435f96 100644 --- a/tests/pos-custom-args/captures/capt1.scala +++ b/tests/pos-custom-args/captures/capt1.scala @@ -1,14 +1,14 @@ class C type Cap = {*} C -def f1(c: Cap): {c} () => c.type = () => c // ok +def f1(c: Cap): {c} () -> c.type = () => c // ok def f2: Int = - val g: {*} Boolean => Int = ??? + val g: {*} Boolean -> Int = ??? val x = g(true) x def f3: Int = - def g: {*} Boolean => Int = ??? + def g: Boolean => Int = ??? def h = g val x = g.apply(true) x @@ -16,10 +16,10 @@ def f3: Int = def foo() = val x: {*} C = ??? val y: {x} C = x - val x2: {x} () => C = ??? - val y2: {x} () => {x} C = x2 + val x2: {x} () -> C = ??? + val y2: {x} () -> {x} C = x2 - val z1: {*} () => Cap = f1(x) + val z1: () => Cap = f1(x) def h[X](a: X)(b: X) = a val z2 = diff --git a/tests/pos-custom-args/captures/capt2.scala b/tests/pos-custom-args/captures/capt2.scala index e3d4cd67b30c..11bb2d5eb7b5 100644 --- a/tests/pos-custom-args/captures/capt2.scala +++ b/tests/pos-custom-args/captures/capt2.scala @@ -9,12 +9,12 @@ def test1() = def test2() = val x: Cap = C() val y = () => { x; () } - def z: (() => Unit) @retains(x) = y - z: (() => Unit) @retains(x) - def z2: (() => Unit) @retains(y) = y - z2: (() => Unit) @retains(y) - val p: {*} () => String = () => "abc" + def z: (() -> Unit) @retains(x) = y + z: (() -> Unit) @retains(x) + def z2: (() -> Unit) @retains(y) = y + z2: (() -> Unit) @retains(y) + val p: {*} () -> String = () => "abc" val q: {p} C = ??? - p: ({p} () => String) + p: ({p} () -> String) diff --git a/tests/pos-custom-args/captures/cc-expand.scala b/tests/pos-custom-args/captures/cc-expand.scala index eedc95554b17..7bce1ea8387e 100644 --- a/tests/pos-custom-args/captures/cc-expand.scala +++ b/tests/pos-custom-args/captures/cc-expand.scala @@ -8,14 +8,14 @@ object Test: def test(ct: CT, dt: CT) = - def x0: A => {ct} B = ??? + def x0: A -> {ct} B = ??? - def x1: A => B @retains(ct) = ??? - def x2: A => B => C @retains(ct) = ??? - def x3: A => () => B => C @retains(ct) = ??? + def x1: A -> B @retains(ct) = ??? + def x2: A -> B -> C @retains(ct) = ??? + def x3: A -> () -> B -> C @retains(ct) = ??? - def x4: (x: A @retains(ct)) => B => C = ??? + def x4: (x: A @retains(ct)) -> B -> C = ??? - def x5: A => (x: B @retains(ct)) => () => C @retains(dt) = ??? - def x6: A => (x: B @retains(ct)) => (() => C @retains(dt)) @retains(x, dt) = ??? - def x7: A => (x: B @retains(ct)) => (() => C @retains(dt)) @retains(x) = ??? \ No newline at end of file + def x5: A -> (x: B @retains(ct)) -> () -> C @retains(dt) = ??? + def x6: A -> (x: B @retains(ct)) -> (() -> C @retains(dt)) @retains(x, dt) = ??? + def x7: A -> (x: B @retains(ct)) -> (() -> C @retains(dt)) @retains(x) = ??? \ No newline at end of file diff --git a/tests/pos-custom-args/captures/impurefun.scala b/tests/pos-custom-args/captures/impurefun.scala new file mode 100644 index 000000000000..6e31008fe54a --- /dev/null +++ b/tests/pos-custom-args/captures/impurefun.scala @@ -0,0 +1,8 @@ +object Test: + + val f: ImpureFunction1[Int, Int] = (x: Int) => x + 1 + + val g: Int -> Int = (x: Int) => x + 1 + + val h: Int => Int = (x: Int) => x + 1 + diff --git a/tests/pos-custom-args/captures/lazylists-mono.scala b/tests/pos-custom-args/captures/lazylists-mono.scala index 82c44abf703a..44ab36ded6a2 100644 --- a/tests/pos-custom-args/captures/lazylists-mono.scala +++ b/tests/pos-custom-args/captures/lazylists-mono.scala @@ -6,21 +6,21 @@ type Cap = {*} CC def test(E: Cap) = trait LazyList[+A]: - protected def contents: {E} () => (A, {E} LazyList[A]) + protected def contents: {E} () -> (A, {E} LazyList[A]) def isEmpty: Boolean def head: A = contents()._1 def tail: {E} LazyList[A] = contents()._2 - class LazyCons[+A](override val contents: {E} () => (A, {E} LazyList[A])) + class LazyCons[+A](override val contents: {E} () -> (A, {E} LazyList[A])) extends LazyList[A]: def isEmpty: Boolean = false object LazyNil extends LazyList[Nothing]: - def contents: {E} () => (Nothing, LazyList[Nothing]) = ??? + def contents: {E} () -> (Nothing, LazyList[Nothing]) = ??? def isEmpty: Boolean = true extension [A](xs: {E} LazyList[A]) - def map[B](f: {E} A => B): {E} LazyList[B] = + def map[B](f: {E} A -> B): {E} LazyList[B] = if xs.isEmpty then LazyNil else val cons = () => (f(xs.head), xs.tail.map(f)) diff --git a/tests/pos-custom-args/captures/lazylists.scala b/tests/pos-custom-args/captures/lazylists.scala index 17d5f8546edc..bf3e0300b5b5 100644 --- a/tests/pos-custom-args/captures/lazylists.scala +++ b/tests/pos-custom-args/captures/lazylists.scala @@ -14,7 +14,7 @@ object LazyNil extends LazyList[Nothing]: def tail = ??? extension [A](xs: {*} LazyList[A]) - def map[B](f: {*} A => B): {xs, f} LazyList[B] = + def map[B](f: A => B): {xs, f} LazyList[B] = final class Mapped extends LazyList[B]: this: ({xs, f} Mapped) => diff --git a/tests/pos-custom-args/captures/lazylists1.scala b/tests/pos-custom-args/captures/lazylists1.scala index 4c8006fb0e29..7fbdde87ad9b 100644 --- a/tests/pos-custom-args/captures/lazylists1.scala +++ b/tests/pos-custom-args/captures/lazylists1.scala @@ -13,23 +13,23 @@ object LazyNil extends LazyList[Nothing]: def head = ??? def tail = ??? -final class LazyCons[+T](val x: T, val xs: {*} () => {*} LazyList[T]) extends LazyList[T]: +final class LazyCons[+T](val x: T, val xs: Int => {*} LazyList[T]) extends LazyList[T]: this: ({*} LazyList[T]) => def isEmpty = false def head = x - def tail: {this} LazyList[T] = xs() + def tail: {this} LazyList[T] = xs(0) extension [A](xs: {*} LazyList[A]) - def map[B](f: {*} A => B): {xs, f} LazyList[B] = + def map[B](f: A => B): {xs, f} LazyList[B] = if xs.isEmpty then LazyNil - else LazyCons(f(xs.head), () => xs.tail.map(f)) + else LazyCons(f(xs.head), x => xs.tail.map(f)) def test(cap1: Cap, cap2: Cap) = def f(x: String): String = if cap1 == cap1 then "" else "a" def g(x: String): String = if cap2 == cap2 then "" else "a" - val xs = LazyCons("", () => if f("") == f("") then LazyNil else LazyNil) + val xs = new LazyCons("", x => if f("") == f("") then LazyNil else LazyNil) val xsc: {cap1} LazyList[String] = xs val ys = xs.map(g) val ysc: {cap1, cap2} LazyList[String] = ys diff --git a/tests/pos-custom-args/captures/lazyref.scala b/tests/pos-custom-args/captures/lazyref.scala index 39748b00506b..2ab770178a16 100644 --- a/tests/pos-custom-args/captures/lazyref.scala +++ b/tests/pos-custom-args/captures/lazyref.scala @@ -1,15 +1,14 @@ -class CC -type Cap = {*} CC +@annotation.capability class Cap -class LazyRef[T](val elem: {*} () => T): +class LazyRef[T](val elem: () => T): val get = elem - def map[U](f: {*} T => U): {f, this} LazyRef[U] = + def map[U](f: T => U): {f, this} LazyRef[U] = new LazyRef(() => f(elem())) -def map[A, B](ref: {*} LazyRef[A], f: {*} A => B): {f, ref} LazyRef[B] = +def map[A, B](ref: {*} LazyRef[A], f: A => B): {f, ref} LazyRef[B] = new LazyRef(() => f(ref.elem())) -def mapc[A, B]: (ref: {*} LazyRef[A], f: {*} A => B) => {f, ref} LazyRef[B] = +def mapc[A, B]: (ref: {*} LazyRef[A], f: A => B) => {f, ref} LazyRef[B] = (ref1, f1) => map[A, B](ref1, f1) def test(cap1: Cap, cap2: Cap) = diff --git a/tests/pos-custom-args/captures/list-encoding.scala b/tests/pos-custom-args/captures/list-encoding.scala index 74bc8bd2b099..59ae61273af7 100644 --- a/tests/pos-custom-args/captures/list-encoding.scala +++ b/tests/pos-custom-args/captures/list-encoding.scala @@ -3,10 +3,10 @@ package listEncoding class Cap type Op[T, C] = - {*} (v: T) => {*} (s: C) => C + (v: T) => (s: C) => C type List[T] = - [C] => (op: Op[T, C]) => {op} (s: C) => C + [C] -> (op: Op[T, C]) -> {op} (s: C) -> C def nil[T]: List[T] = [C] => (op: Op[T, C]) => (s: C) => s diff --git a/tests/pos-custom-args/captures/lists.scala b/tests/pos-custom-args/captures/lists.scala index 139f885ec87a..f52727af7b94 100644 --- a/tests/pos-custom-args/captures/lists.scala +++ b/tests/pos-custom-args/captures/lists.scala @@ -2,7 +2,7 @@ abstract class LIST[+T]: def isEmpty: Boolean def head: T def tail: LIST[T] - def map[U](f: {*} T => U): LIST[U] = + def map[U](f: {*} T -> U): LIST[U] = if isEmpty then NIL else CONS(f(head), tail.map(f)) @@ -15,11 +15,10 @@ object NIL extends LIST[Nothing]: def head = ??? def tail = ??? -def map[A, B](f: {*} A => B)(xs: LIST[A]): LIST[B] = +def map[A, B](f: A => B)(xs: LIST[A]): LIST[B] = xs.map(f) -class C -type Cap = {*} C +@annotation.capability class Cap def test(c: Cap, d: Cap, e: Cap) = def f(x: Cap): Unit = if c == x then () @@ -29,63 +28,63 @@ def test(c: Cap, d: Cap, e: Cap) = val zs = val z = g CONS(z, ys) - val zsc: LIST[{d, y} Cap => Unit] = zs + val zsc: LIST[{d, y} Cap -> Unit] = zs val z1 = zs.head - val z1c: {y, d} Cap => Unit = z1 + val z1c: {y, d} Cap -> Unit = z1 val ys1 = zs.tail val y1 = ys1.head def m1[A, B] = - (f: {*} A => B) => (xs: LIST[A]) => xs.map(f) + (f: A => B) => (xs: LIST[A]) => xs.map(f) - def m1c: (f: {*} String => Int) => {f} LIST[String] => LIST[Int] = m1[String, Int] + def m1c: (f: String => Int) -> {f} LIST[String] -> LIST[Int] = m1[String, Int] def m2 = [A, B] => - (f: {*} A => B) => (xs: LIST[A]) => xs.map(f) + (f: A => B) => (xs: LIST[A]) => xs.map(f) - def m2c: [A, B] => (f: {*} A => B) => {f} LIST[A] => LIST[B] = m2 + def m2c: [A, B] -> (f: A => B) -> {f} LIST[A] -> LIST[B] = m2 def eff[A](x: A) = if x == e then x else x val eff2 = [A] => (x: A) => if x == e then x else x - val a0 = identity[{d, y} Cap => Unit] - val a0c: ({d, y} Cap => Unit) => {d, y} Cap => Unit = a0 - val a1 = zs.map[{d, y} Cap => Unit](a0) - val a1c: LIST[{d, y} Cap => Unit] = a1 - val a2 = zs.map[{d, y} Cap => Unit](identity[{d, y} Cap => Unit]) - val a2c: LIST[{d, y} Cap => Unit] = a2 - val a3 = zs.map(identity[{d, y} Cap => Unit]) - val a3c: LIST[{d, y} Cap => Unit] = a3 + val a0 = identity[{d, y} Cap -> Unit] + val a0c: ({d, y} Cap -> Unit) -> {d, y} Cap -> Unit = a0 + val a1 = zs.map[{d, y} Cap -> Unit](a0) + val a1c: LIST[{d, y} Cap -> Unit] = a1 + val a2 = zs.map[{d, y} Cap -> Unit](identity[{d, y} Cap -> Unit]) + val a2c: LIST[{d, y} Cap -> Unit] = a2 + val a3 = zs.map(identity[{d, y} Cap -> Unit]) + val a3c: LIST[{d, y} Cap -> Unit] = a3 val a4 = zs.map(identity) - val a4c: LIST[{d, c} Cap => Unit] = a4 - val a5 = map[{d, y} Cap => Unit, {d, y} Cap => Unit](identity)(zs) - val a5c: LIST[{d, c} Cap => Unit] = a5 - val a6 = m1[{d, y} Cap => Unit, {d, y} Cap => Unit](identity)(zs) - val a6c: LIST[{d, c} Cap => Unit] = a6 + val a4c: LIST[{d, c} Cap -> Unit] = a4 + val a5 = map[{d, y} Cap -> Unit, {d, y} Cap -> Unit](identity)(zs) + val a5c: LIST[{d, c} Cap -> Unit] = a5 + val a6 = m1[{d, y} Cap -> Unit, {d, y} Cap -> Unit](identity)(zs) + val a6c: LIST[{d, c} Cap -> Unit] = a6 - val b0 = eff[{d, y} Cap => Unit] - val b0c: {e} ({d, y} Cap => Unit) => {d, y} Cap => Unit = b0 - val b1 = zs.map[{d, y} Cap => Unit](a0) - val b1c: {e} LIST[{d, y} Cap => Unit] = b1 - val b2 = zs.map[{d, y} Cap => Unit](eff[{d, y} Cap => Unit]) - val b2c: {e} LIST[{d, y} Cap => Unit] = b2 - val b3 = zs.map(eff[{d, y} Cap => Unit]) - val b3c: {e} LIST[{d, y} Cap => Unit] = b3 + val b0 = eff[{d, y} Cap -> Unit] + val b0c: {e} ({d, y} Cap -> Unit) -> {d, y} Cap -> Unit = b0 + val b1 = zs.map[{d, y} Cap -> Unit](a0) + val b1c: {e} LIST[{d, y} Cap -> Unit] = b1 + val b2 = zs.map[{d, y} Cap -> Unit](eff[{d, y} Cap -> Unit]) + val b2c: {e} LIST[{d, y} Cap -> Unit] = b2 + val b3 = zs.map(eff[{d, y} Cap -> Unit]) + val b3c: {e} LIST[{d, y} Cap -> Unit] = b3 val b4 = zs.map(eff) - val b4c: {e} LIST[{d, c} Cap => Unit] = b4 - val b5 = map[{d, y} Cap => Unit, {d, y} Cap => Unit](eff)(zs) - val b5c: {e} LIST[{d, c} Cap => Unit] = b5 - val b6 = m1[{d, y} Cap => Unit, {d, y} Cap => Unit](eff)(zs) - val b6c: {e} LIST[{d, c} Cap => Unit] = b6 + val b4c: {e} LIST[{d, c} Cap -> Unit] = b4 + val b5 = map[{d, y} Cap -> Unit, {d, y} Cap -> Unit](eff)(zs) + val b5c: {e} LIST[{d, c} Cap -> Unit] = b5 + val b6 = m1[{d, y} Cap -> Unit, {d, y} Cap -> Unit](eff)(zs) + val b6c: {e} LIST[{d, c} Cap -> Unit] = b6 - val c0 = eff2[{d, y} Cap => Unit] - val c0c: {e} ({d, y} Cap => Unit) => {d, y} Cap => Unit = c0 - val c1 = zs.map[{d, y} Cap => Unit](a0) - val c1c: {e} LIST[{d, y} Cap => Unit] = c1 - val c2 = zs.map[{d, y} Cap => Unit](eff2[{d, y} Cap => Unit]) - val c2c: {e} LIST[{d, y} Cap => Unit] = c2 - val c3 = zs.map(eff2[{d, y} Cap => Unit]) - val c3c: {e} LIST[{d, y} Cap => Unit] = c3 + val c0 = eff2[{d, y} Cap -> Unit] + val c0c: {e} ({d, y} Cap -> Unit) -> {d, y} Cap -> Unit = c0 + val c1 = zs.map[{d, y} Cap -> Unit](a0) + val c1c: {e} LIST[{d, y} Cap -> Unit] = c1 + val c2 = zs.map[{d, y} Cap -> Unit](eff2[{d, y} Cap -> Unit]) + val c2c: {e} LIST[{d, y} Cap -> Unit] = c2 + val c3 = zs.map(eff2[{d, y} Cap -> Unit]) + val c3c: {e} LIST[{d, y} Cap -> Unit] = c3 diff --git a/tests/pos-custom-args/captures/pairs.scala b/tests/pos-custom-args/captures/pairs.scala index 4f23a086a075..14d484ff21b1 100644 --- a/tests/pos-custom-args/captures/pairs.scala +++ b/tests/pos-custom-args/captures/pairs.scala @@ -1,6 +1,5 @@ -class C -type Cap = {*} C +@annotation.capability class Cap object Generic: @@ -13,13 +12,13 @@ object Generic: def g(x: Cap): Unit = if d == x then () val p = Pair(f, g) val x1 = p.fst - val x1c: {c} Cap => Unit = x1 + val x1c: {c} Cap -> Unit = x1 val y1 = p.snd - val y1c: {d} Cap => Unit = y1 + val y1c: {d} Cap -> Unit = y1 object Monomorphic: - class Pair(val x: {*} Cap => Unit, val y: {*} Cap => Unit): + class Pair(val x: Cap => Unit, val y: {*} Cap -> Unit): def fst = x def snd = y @@ -28,6 +27,6 @@ object Monomorphic: def g(x: Cap): Unit = if d == x then () val p = Pair(f, g) val x1 = p.fst - val x1c: {c} Cap => Unit = x1 + val x1c: {c} Cap -> Unit = x1 val y1 = p.snd - val y1c: {d} Cap => Unit = y1 + val y1c: {d} Cap -> Unit = y1 diff --git a/tests/pos-custom-args/captures/try.scala b/tests/pos-custom-args/captures/try.scala index a50eeabfb3a3..14773cd5be0f 100644 --- a/tests/pos-custom-args/captures/try.scala +++ b/tests/pos-custom-args/captures/try.scala @@ -3,7 +3,7 @@ import language.experimental.erasedDefinitions class CT[E <: Exception] type CanThrow[E <: Exception] = CT[E] @retains(*) -infix type throws[R, E <: Exception] = (erased CanThrow[E]) ?=> R +infix type throws[R, E <: Exception] = (erased CanThrow[E]) ?-> R class Fail extends Exception @@ -12,7 +12,7 @@ def raise[E <: Exception](e: E): Nothing throws E = throw e def foo(x: Boolean): Int throws Fail = if x then 1 else raise(Fail()) -def handle[E <: Exception, R](op: (erased CanThrow[E]) => R)(handler: E => R): R = +def handle[E <: Exception, R](op: (erased CanThrow[E]) -> R)(handler: E -> R): R = erased val x: CanThrow[E] = ??? try op(x) catch case ex: E => handler(ex) diff --git a/tests/pos-custom-args/captures/try3.scala b/tests/pos-custom-args/captures/try3.scala index b29ad2d4b352..b8937bec00f3 100644 --- a/tests/pos-custom-args/captures/try3.scala +++ b/tests/pos-custom-args/captures/try3.scala @@ -2,8 +2,7 @@ import language.experimental.erasedDefinitions import annotation.capability import java.io.IOException -class CT[-E] // variance is needed for correct rechecking inference -type CanThrow[E] = {*} CT[E] +@annotation.capability class CanThrow[-E] def handle[E <: Exception, T](op: CanThrow[E] ?=> T)(handler: E => T): T = val x: CanThrow[E] = ??? @@ -14,7 +13,7 @@ def raise[E <: Exception](ex: E)(using CanThrow[E]): Nothing = throw ex def test1: Int = - def f(a: Boolean): Boolean => CanThrow[IOException] ?=> Int = + def f(a: Boolean): Boolean -> CanThrow[IOException] ?-> Int = handle { if !a then raise(IOException()) (b: Boolean) => (_: CanThrow[IOException]) ?=> diff --git a/tests/pos-custom-args/captures/vars.scala b/tests/pos-custom-args/captures/vars.scala index aca56c55f386..12721158a2bb 100644 --- a/tests/pos-custom-args/captures/vars.scala +++ b/tests/pos-custom-args/captures/vars.scala @@ -1,18 +1,17 @@ -class CC -type Cap = {*} CC +@annotation.capability class Cap def test(cap1: Cap, cap2: Cap) = def f(x: String): String = if cap1 == cap1 then "" else "a" var x = f val y = x val z = () => if x("") == "" then "a" else "b" - val zc: {cap1} () => String = z + val zc: {cap1} () -> String = z val z2 = () => { x = identity } - val z2c: {cap1} () => Unit = z2 + val z2c: {cap1} () -> Unit = z2 class Ref: - var elem: {cap1} String => String = null + var elem: {cap1} String -> String = null val r = Ref() r.elem = f - val fc: {cap1} String => String = r.elem + val fc: {cap1} String -> String = r.elem diff --git a/tests/pos/i12723.scala b/tests/pos/i12723.scala index d1cab3ede638..022a3a458f04 100644 --- a/tests/pos/i12723.scala +++ b/tests/pos/i12723.scala @@ -1,10 +1,10 @@ class Fun[|*|[_, _]] { - enum ->[A, B] { - case BiId[X, Y]() extends ((X |*| Y) -> (X |*| Y)) + enum -->[A, B] { + case BiId[X, Y]() extends ((X |*| Y) --> (X |*| Y)) } - def go[A, B](f: A -> B): Unit = + def go[A, B](f: A --> B): Unit = f match { - case ->.BiId() => () + case -->.BiId() => () } } diff --git a/tests/pos/impurefun.scala b/tests/pos/impurefun.scala new file mode 100644 index 000000000000..c9f4c54a0b90 --- /dev/null +++ b/tests/pos/impurefun.scala @@ -0,0 +1,4 @@ +object Test: + + val f: ImpureFunction1[Int, Int] = (x: Int) => x + 1 + From 9bc09ee3f71adb3e5b785fad77da0d73799e8e7c Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Fri, 21 Jan 2022 14:27:43 +0100 Subject: [PATCH 13/24] Updates for latest master - Fix rebase breakage - weaken test in TreePickler that was introduced in the meantime since the last rebase (this one needs follow up) - adapt to latest restrictions on rhs of erased definitions --- compiler/src/dotty/tools/dotc/ast/MainProxies.scala | 2 +- compiler/src/dotty/tools/dotc/core/Definitions.scala | 1 + compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala | 2 +- compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala | 7 ++----- tests/pos-custom-args/captures/byname.scala | 2 +- tests/pos-custom-args/captures/try.scala | 2 +- 6 files changed, 7 insertions(+), 9 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/ast/MainProxies.scala b/compiler/src/dotty/tools/dotc/ast/MainProxies.scala index 2eafeca16e39..af7c86d1f604 100644 --- a/compiler/src/dotty/tools/dotc/ast/MainProxies.scala +++ b/compiler/src/dotty/tools/dotc/ast/MainProxies.scala @@ -4,7 +4,7 @@ package ast import core._ import Symbols._, Types._, Contexts._, Decorators._, util.Spans._, Flags._, Constants._ import StdNames.nme -import ast.Trees._ +import Trees._ /** Generate proxy classes for @main functions. * A function like diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 576bdf6d1b95..9b65ce971b64 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -1349,6 +1349,7 @@ class Definitions { funTypeArray(funTypeIdx(isContextual, isErased, isImpure))(n).symbol @tu lazy val Function0_apply: Symbol = Function0.requiredMethod(nme.apply) + @tu lazy val ContextFunction0_apply: Symbol = ContextFunction0.requiredMethod(nme.apply) @tu lazy val Function0: Symbol = FunctionSymbol(0) @tu lazy val Function1: Symbol = FunctionSymbol(1) diff --git a/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala b/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala index 8f5910c3dd56..c83d65ff483d 100644 --- a/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala +++ b/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala @@ -202,7 +202,7 @@ class TreePickler(pickler: TastyPickler) { else if (tpe.prefix == NoPrefix) { writeByte(if (tpe.isType) TYPEREFdirect else TERMREFdirect) if !symRefs.contains(sym) && !sym.isPatternBound && !sym.hasAnnotation(defn.QuotedRuntimePatterns_patternTypeAnnot) then - report.error(i"pickling reference to as yet undefined $tpe with symbol ${sym}", sym.srcPos) + report.log(i"pickling reference to as yet undefined $tpe with symbol ${sym}", sym.srcPos) pickleSymRef(sym) } else tpe.designator match { diff --git a/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala index e78776cb0158..3cdf39cb56e6 100644 --- a/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala +++ b/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala @@ -467,11 +467,8 @@ class CheckCaptures extends Recheck: recheckFinish(result, arg, pt) override def recheckApply(tree: Apply, pt: Type)(using Context): Type = - if tree.symbol == defn.cbnArg then - recheckByNameArg(tree.args(0), pt) - else - includeCallCaptures(tree.symbol, tree.srcPos) - super.recheckApply(tree, pt) + includeCallCaptures(tree.symbol, tree.srcPos) + super.recheckApply(tree, pt) override def recheck(tree: Tree, pt: Type = WildcardType)(using Context): Type = val res = super.recheck(tree, pt) diff --git a/tests/pos-custom-args/captures/byname.scala b/tests/pos-custom-args/captures/byname.scala index 917154079b36..a3d80f31d579 100644 --- a/tests/pos-custom-args/captures/byname.scala +++ b/tests/pos-custom-args/captures/byname.scala @@ -5,6 +5,6 @@ class I def test(cap1: Cap, cap2: Cap): {cap1} I = def f() = if cap1 == cap1 then I() else I() - def h(x: => {cap1} I) = x + def h(x: /*=>*/ {cap1} I) = x // TODO: enable cbn h(f()) diff --git a/tests/pos-custom-args/captures/try.scala b/tests/pos-custom-args/captures/try.scala index 14773cd5be0f..73e8a1d27ec7 100644 --- a/tests/pos-custom-args/captures/try.scala +++ b/tests/pos-custom-args/captures/try.scala @@ -13,7 +13,7 @@ def foo(x: Boolean): Int throws Fail = if x then 1 else raise(Fail()) def handle[E <: Exception, R](op: (erased CanThrow[E]) -> R)(handler: E -> R): R = - erased val x: CanThrow[E] = ??? + erased val x: CanThrow[E] = ??? : CanThrow[E] try op(x) catch case ex: E => handler(ex) From e349accd907f4971115e7aca7d3e4970a4f75df1 Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Thu, 23 Dec 2021 18:47:33 +0100 Subject: [PATCH 14/24] Shorthand for curried functions Propagate capture sets to the right in curried functions. Example: {x} A -> B -> C is a shorthand for {x} A -> {x} B -> C or: (x: {*} A) -> B -> C is a shorthand for (x: {*} A) -> {x} B -> C or: ({*} A) -> B -> C is a shorthand for (x$0: {*} A) -> {x$0} B -> C Also: allow empty capture sets in types This gives a more convenient override to disable capture set propagation in curried types than wrapping in a type alias. E.g. compare {x} A -> {} B -> C with {x} A -> Protect[B -> C] where type Protect[X] = X Also: refactoring to move setup code from Rechecker and CheckCaptures into a joint class cc.Setup. --- compiler/src/dotty/tools/dotc/cc/Setup.scala | 328 ++++++++++++++++++ .../tools/dotc/config/ScalaSettings.scala | 1 + .../dotty/tools/dotc/parsing/Parsers.scala | 16 +- .../tools/dotc/printing/RefinedPrinter.scala | 2 +- .../dotty/tools/dotc/transform/Recheck.scala | 163 +++------ .../tools/dotc/typer/CheckCaptures.scala | 214 ++---------- .../captures/curried-simplified.check | 42 +++ .../captures/curried-simplified.scala | 21 ++ .../captures/capt-depfun.scala | 4 +- 9 files changed, 468 insertions(+), 323 deletions(-) create mode 100644 compiler/src/dotty/tools/dotc/cc/Setup.scala create mode 100644 tests/neg-custom-args/captures/curried-simplified.check create mode 100644 tests/neg-custom-args/captures/curried-simplified.scala diff --git a/compiler/src/dotty/tools/dotc/cc/Setup.scala b/compiler/src/dotty/tools/dotc/cc/Setup.scala new file mode 100644 index 000000000000..b5ccb002e209 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/cc/Setup.scala @@ -0,0 +1,328 @@ +package dotty.tools +package dotc +package cc + +import core._ +import Phases.*, DenotTransformers.*, SymDenotations.* +import Contexts.*, Names.*, Flags.*, Symbols.*, Decorators.* +import Types.*, StdNames.* +import config.Printers.capt +import ast.tpd +import transform.Recheck.* + +class Setup( + preRecheckPhase: DenotTransformer, + thisPhase: DenotTransformer, + recheckDef: (tpd.ValOrDefDef, Symbol) => Context ?=> Unit) +extends tpd.TreeTraverser: + import tpd.* + + private def depFun(tycon: Type, argTypes: List[Type], resType: Type)(using Context): Type = + MethodType.companion( + isContextual = defn.isContextFunctionClass(tycon.classSymbol), + isErased = defn.isErasedFunctionClass(tycon.classSymbol) + )(argTypes, resType) + .toFunctionType(isJava = false, alwaysDependent = true) + + private def box(tp: Type)(using Context): Type = tp match + case CapturingType(parent, refs, false) => CapturingType(parent, refs, true) + case _ => tp + + private def setBoxed(tp: Type)(using Context) = tp match + case AnnotatedType(_, annot) if annot.symbol == defn.RetainsAnnot => + annot.tree.setBoxedCapturing() + case _ => + + private def addBoxes(using Context) = new TypeTraverser: + def traverse(t: Type) = + t match + case AppliedType(tycon, args) if !defn.isNonRefinedFunction(t) => + args.foreach(setBoxed) + case TypeBounds(lo, hi) => + setBoxed(lo); setBoxed(hi) + case _ => + traverseChildren(t) + + /** Perform the following transformation steps everywhere in a type: + * 1. Drop retains annotations + * 2. Turn plain function types into dependent function types, so that + * we can refer to their parameter in capture sets. Currently this is + * only done at the toplevel, i.e. for function types that are not + * themselves argument types of other function types. Without this restriction + * boxmap-paper.scala fails. Need to figure out why. + * 3. Refine other class types C by adding capture set variables to their parameter getters + * (see addCaptureRefinements) + * 4. Add capture set variables to all types that can be tracked + * + * Polytype bounds are only cleaned using step 1, but not otherwise transformed. + */ + private def mapInferred(using Context) = new TypeMap: + + /** Drop @retains annotations everywhere */ + object cleanup extends TypeMap: + def apply(t: Type) = t match + case AnnotatedType(parent, annot) if annot.symbol == defn.RetainsAnnot => + apply(parent) + case _ => + mapOver(t) + + /** Refine a possibly applied class type C where the class has tracked parameters + * x_1: T_1, ..., x_n: T_n to C { val x_1: CV_1 T_1, ..., val x_n: CV_n T_n } + * where CV_1, ..., CV_n are fresh capture sets. + */ + def addCaptureRefinements(tp: Type): Type = tp match + case _: TypeRef | _: AppliedType if tp.typeParams.isEmpty => + tp.typeSymbol match + case cls: ClassSymbol if !defn.isFunctionClass(cls) => + cls.paramGetters.foldLeft(tp) { (core, getter) => + if getter.termRef.isTracked then + val getterType = tp.memberInfo(getter).strippedDealias + RefinedType(core, getter.name, CapturingType(getterType, CaptureSet.Var(), boxed = false)) + .showing(i"add capture refinement $tp --> $result", capt) + else + core + } + case _ => tp + case _ => tp + + /** Should a capture set variable be added on type `tp`? */ + def canHaveInferredCapture(tp: Type): Boolean = + tp.typeParams.isEmpty && tp.match + case tp: (TypeRef | AppliedType) => + val sym = tp.typeSymbol + if sym.isClass then !sym.isValueClass && sym != defn.AnyClass + else canHaveInferredCapture(tp.superType.dealias) + case tp: (RefinedOrRecType | MatchType) => + canHaveInferredCapture(tp.underlying) + case tp: AndType => + canHaveInferredCapture(tp.tp1) && canHaveInferredCapture(tp.tp2) + case tp: OrType => + canHaveInferredCapture(tp.tp1) || canHaveInferredCapture(tp.tp2) + case _ => + false + + /** Add a capture set variable to `tp` if necessary, or maybe pull out + * an embedded capture set variables from a part of `tp`. + */ + def addVar(tp: Type) = tp match + case tp @ RefinedType(parent @ CapturingType(parent1, refs, boxed), rname, rinfo) => + CapturingType(tp.derivedRefinedType(parent1, rname, rinfo), refs, boxed) + case tp: RecType => + tp.parent match + case CapturingType(parent1, refs, boxed) => + CapturingType(tp.derivedRecType(parent1), refs, boxed) + case _ => + tp // can return `tp` here since unlike RefinedTypes, RecTypes are never created + // by `mapInferred`. Hence if the underlying type admits capture variables + // a variable was already added, and the first case above would apply. + case AndType(CapturingType(parent1, refs1, boxed1), CapturingType(parent2, refs2, boxed2)) => + assert(refs1.asVar.elems.isEmpty) + assert(refs2.asVar.elems.isEmpty) + assert(boxed1 == boxed2) + CapturingType(AndType(parent1, parent2), refs1, boxed1) + case tp @ OrType(CapturingType(parent1, refs1, boxed1), CapturingType(parent2, refs2, boxed2)) => + assert(refs1.asVar.elems.isEmpty) + assert(refs2.asVar.elems.isEmpty) + assert(boxed1 == boxed2) + CapturingType(OrType(parent1, parent2, tp.isSoft), refs1, boxed1) + case tp @ OrType(CapturingType(parent1, refs1, boxed1), tp2) => + CapturingType(OrType(parent1, tp2, tp.isSoft), refs1, boxed1) + case tp @ OrType(tp1, CapturingType(parent2, refs2, boxed2)) => + CapturingType(OrType(tp1, parent2, tp.isSoft), refs2, boxed2) + case _ if canHaveInferredCapture(tp) => + CapturingType(tp, CaptureSet.Var(), boxed = false) + case _ => + tp + + var isTopLevel = true + + def mapNested(ts: List[Type]): List[Type] = + val saved = isTopLevel + isTopLevel = false + try ts.mapConserve(this) finally isTopLevel = saved + + def apply(t: Type) = + val t1 = t match + case AnnotatedType(parent, annot) if annot.symbol == defn.RetainsAnnot => + apply(parent) + case tp @ AppliedType(tycon, args) => + val tycon1 = this(tycon) + if defn.isNonRefinedFunction(tp) then + val args1 = mapNested(args.init) + val res1 = this(args.last) + if isTopLevel then + depFun(tycon1, args1, res1) + .showing(i"add function refinement $tp --> $result", capt) + else + tp.derivedAppliedType(tycon1, args1 :+ res1) + else + tp.derivedAppliedType(tycon1, args.mapConserve(arg => box(this(arg)))) + case tp @ RefinedType(core, rname, rinfo) if defn.isFunctionType(tp) => + val rinfo1 = apply(rinfo) + if rinfo1 ne rinfo then rinfo1.toFunctionType(isJava = false, alwaysDependent = true) + else tp + case tp: MethodType => + tp.derivedLambdaType( + paramInfos = mapNested(tp.paramInfos), + resType = this(tp.resType)) + case tp: TypeLambda => + // Don't recurse into parameter bounds, just cleanup any stray retains annotations + tp.derivedLambdaType( + paramInfos = tp.paramInfos.mapConserve(cleanup(_).bounds), + resType = this(tp.resType)) + case _ => + mapOver(t) + addVar(addCaptureRefinements(t1)) + end mapInferred + + private def expandAbbreviations(using Context) = new TypeMap: + + def propagateMethodResult(tp: Type, outerCs: CaptureSet, deep: Boolean): Type = tp match + case tp: MethodType => + if deep then + val tp1 = tp.derivedLambdaType(paramInfos = tp.paramInfos.mapConserve(this)) + propagateMethodResult(tp1, outerCs, deep = false) + else + val localCs = CaptureSet(tp.paramRefs.filter(_.isTracked)*) + tp.derivedLambdaType( + resType = propagateEnclosing(tp.resType, CaptureSet.empty, outerCs ++ localCs)) + + def propagateDepFunctionResult(tp: Type, outerCs: CaptureSet, deep: Boolean): Type = tp match + case tp @ RefinedType(parent, nme.apply, rinfo: MethodType) => + val rinfo1 = propagateMethodResult(rinfo, outerCs, deep) + if rinfo1 ne rinfo then rinfo1.toFunctionType(isJava = false, alwaysDependent = true) + else tp + + def propagateEnclosing(tp: Type, currentCs: CaptureSet, outerCs: CaptureSet): Type = tp match + case tp @ AppliedType(tycon, args) if defn.isFunctionClass(tycon.typeSymbol) => + val tycon1 = this(tycon) + val args1 = args.init.mapConserve(this) + val tp1 = + if args1.exists(!_.captureSet.isAlwaysEmpty) then + val propagated = propagateDepFunctionResult( + depFun(tycon, args1, args.last), currentCs ++ outerCs, deep = false) + propagated match + case RefinedType(_, _, mt: MethodType) => + val following = mt.resType.captureSet.elems + if mt.paramRefs.exists(following.contains(_)) then propagated + else tp.derivedAppliedType(tycon1, args1 :+ mt.resType) + else + val resType1 = propagateEnclosing( + args.last, CaptureSet.empty, currentCs ++ outerCs) + tp.derivedAppliedType(tycon1, args1 :+ resType1) + tp1.capturing(outerCs) + case tp @ RefinedType(parent, nme.apply, rinfo: MethodType) if defn.isFunctionType(tp) => + propagateDepFunctionResult(tp, currentCs ++ outerCs, deep = true) + .capturing(outerCs) + case _ => + mapOver(tp) + + def apply(tp: Type): Type = tp match + case CapturingType(parent, cs, boxed) => + tp.derivedCapturingType(propagateEnclosing(parent, cs, CaptureSet.empty), cs) + case _ => + propagateEnclosing(tp, CaptureSet.empty, CaptureSet.empty) + end expandAbbreviations + + private def transformInferredType(tp: Type, boxed: Boolean)(using Context): Type = + val tp1 = mapInferred(tp) + if boxed then box(tp1) else tp1 + + private def transformExplicitType(tp: Type, boxed: Boolean)(using Context): Type = + addBoxes.traverse(tp) + if boxed then setBoxed(tp) + if ctx.settings.YccNoAbbrev.value then tp + else expandAbbreviations(tp) + + // Substitute parameter symbols in `from` to paramRefs in corresponding + // method or poly types `to`. We use a single BiTypeMap to do everything. + private class SubstParams(from: List[List[Symbol]], to: List[LambdaType])(using Context) + extends DeepTypeMap, BiTypeMap: + + def apply(t: Type): Type = t match + case t: NamedType => + val sym = t.symbol + def outer(froms: List[List[Symbol]], tos: List[LambdaType]): Type = + def inner(from: List[Symbol], to: List[ParamRef]): Type = + if from.isEmpty then outer(froms.tail, tos.tail) + else if sym eq from.head then to.head + else inner(from.tail, to.tail) + if tos.isEmpty then t + else inner(froms.head, tos.head.paramRefs) + outer(from, to) + case _ => + mapOver(t) + + def inverse(t: Type): Type = t match + case t: ParamRef => + def recur(from: List[LambdaType], to: List[List[Symbol]]): Type = + if from.isEmpty then t + else if t.binder eq from.head then to.head(t.paramNum).namedType + else recur(from.tail, to.tail) + recur(to, from) + case _ => + mapOver(t) + end SubstParams + + private def transformTT(tree: TypeTree, boxed: Boolean)(using Context) = + tree.rememberType( + if tree.isInstanceOf[InferredTypeTree] + then transformInferredType(tree.tpe, boxed) + else transformExplicitType(tree.tpe, boxed)) + + def traverse(tree: Tree)(using Context) = + tree match + case tree @ ValDef(_, tpt: TypeTree, _) if tree.symbol.is(Mutable) => + transformTT(tpt, boxed = true) + traverse(tree.rhs) + case _ => + traverseChildren(tree) + tree match + case tree: TypeTree => + transformTT(tree, boxed = false) + case tree: ValOrDefDef => + val sym = tree.symbol + + // replace an existing symbol info with inferred types + def integrateRT( + info: Type, // symbol info to replace + psymss: List[List[Symbol]], // the local (type and trem) parameter symbols corresponding to `info` + prevPsymss: List[List[Symbol]], // the local parameter symbols seen previously in reverse order + prevLambdas: List[LambdaType] // the outer method and polytypes generated previously in reverse order + ): Type = + info match + case mt: MethodOrPoly => + val psyms = psymss.head + mt.companion(mt.paramNames)( + mt1 => + if !psyms.exists(_.isUpdatedAfter(preRecheckPhase)) && !mt.isParamDependent && prevLambdas.isEmpty then + mt.paramInfos + else + val subst = SubstParams(psyms :: prevPsymss, mt1 :: prevLambdas) + psyms.map(psym => subst(psym.info).asInstanceOf[mt.PInfo]), + mt1 => + integrateRT(mt.resType, psymss.tail, psyms :: prevPsymss, mt1 :: prevLambdas) + ) + case info: ExprType => + info.derivedExprType(resType = + integrateRT(info.resType, psymss, prevPsymss, prevLambdas)) + case _ => + val restp = tree.tpt.knownType + if prevLambdas.isEmpty then restp + else SubstParams(prevPsymss, prevLambdas)(restp) + + if tree.tpt.hasRememberedType && !sym.isConstructor then + val newInfo = integrateRT(sym.info, sym.paramSymss, Nil, Nil) + .showing(i"update info $sym: ${sym.info} --> $result", capt) + if newInfo ne sym.info then + val completer = new LazyType: + def complete(denot: SymDenotation)(using Context) = + denot.info = newInfo + recheckDef(tree, sym) + sym.updateInfoBetween(preRecheckPhase, thisPhase, completer) + case tree: Bind => + val sym = tree.symbol + sym.updateInfoBetween(preRecheckPhase, thisPhase, + transformInferredType(sym.info, boxed = false)) + case _ => +end Setup diff --git a/compiler/src/dotty/tools/dotc/config/ScalaSettings.scala b/compiler/src/dotty/tools/dotc/config/ScalaSettings.scala index 4499e090a212..4e84d9f67319 100644 --- a/compiler/src/dotty/tools/dotc/config/ScalaSettings.scala +++ b/compiler/src/dotty/tools/dotc/config/ScalaSettings.scala @@ -314,6 +314,7 @@ private sealed trait YSettings: val Yrecheck: Setting[Boolean] = BooleanSetting("-Yrecheck", "Run type rechecks (test only)") val Ycc: Setting[Boolean] = BooleanSetting("-Ycc", "Check captured references") val YccDebug: Setting[Boolean] = BooleanSetting("-Ycc-debug", "Debug info for captured references") + val YccNoAbbrev: Setting[Boolean] = BooleanSetting("-Ycc-no-abbrev", "Used in conjunction with -Ycc, suppress type abbreviations") /** Area-specific debug output */ val YexplainLowlevel: Setting[Boolean] = BooleanSetting("-Yexplain-lowlevel", "When explaining type errors, show types at a lower level.") diff --git a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala index 3beffbbf3b05..b5910a7dac7b 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala @@ -896,6 +896,10 @@ object Parsers { def followingIsCaptureSet(): Boolean = val lookahead = in.LookaheadScanner() + def followingIsTypeStart() = + lookahead.nextToken() + canStartInfixTypeTokens.contains(lookahead.token) + || lookahead.token == LBRACKET def recur(): Boolean = (lookahead.isIdent || lookahead.token == THIS) && { lookahead.nextToken() @@ -903,14 +907,10 @@ object Parsers { lookahead.nextToken() recur() else - lookahead.token == RBRACE && { - lookahead.nextToken() - canStartInfixTypeTokens.contains(lookahead.token) - || lookahead.token == LBRACKET - } + lookahead.token == RBRACE && followingIsTypeStart() } lookahead.nextToken() - recur() + if lookahead.token == RBRACE then followingIsTypeStart() else recur() /* --------- OPERAND/OPERATOR STACK --------------------------------------- */ @@ -1486,7 +1486,9 @@ object Parsers { else { accept(TLARROW); typ() } } else if in.token == LBRACE && followingIsCaptureSet() then - val refs = inBraces { commaSeparated(captureRef) } + val refs = inBraces { + if in.token == RBRACE then Nil else commaSeparated(captureRef) + } val t = typ() CapturingTypeTree(refs, t) else if (in.token == INDENT) enclosed(INDENT, typ()) diff --git a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala index 88fac70581ab..06cf18b9dcef 100644 --- a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala @@ -30,7 +30,7 @@ import config.Config import language.implicitConversions import dotty.tools.dotc.util.{NameTransformer, SourcePosition} import dotty.tools.dotc.ast.untpd.{MemberDef, Modifiers, PackageDef, RefTree, Template, TypeDef, ValOrDefDef} -import cc.{EventuallyCapturingType, CaptureSet, toCaptureSet, IllegalCaptureRef} +import cc.{CaptureSet, toCaptureSet, IllegalCaptureRef} class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { diff --git a/compiler/src/dotty/tools/dotc/transform/Recheck.scala b/compiler/src/dotty/tools/dotc/transform/Recheck.scala index 924a444aeff4..bdb49c5a8edb 100644 --- a/compiler/src/dotty/tools/dotc/transform/Recheck.scala +++ b/compiler/src/dotty/tools/dotc/transform/Recheck.scala @@ -7,7 +7,7 @@ import Symbols.*, Contexts.*, Types.*, ContextOps.*, Decorators.*, SymDenotation import Flags.*, SymUtils.*, NameKinds.* import ast.* import Phases.Phase -import DenotTransformers.IdentityDenotTransformer +import DenotTransformers.{IdentityDenotTransformer, DenotTransformer} import NamerOps.{methodType, linkConstructorParams} import NullOpsDecorator.stripNull import typer.ErrorReporting.err @@ -22,9 +22,48 @@ import reporting.trace object Recheck: + import tpd.Tree + /** Attachment key for rechecked types of TypeTrees */ private val RecheckedType = Property.Key[Type] + extension (sym: Symbol) + + /** Update symbol's info to newInfo from prevPhase.next to lastPhase. + * Reset to previous info for phases after lastPhase. + */ + def updateInfoBetween(prevPhase: DenotTransformer, lastPhase: DenotTransformer, newInfo: Type)(using Context): Unit = + if sym.info ne newInfo then + sym.copySymDenotation().installAfter(lastPhase) // reset + sym.copySymDenotation( + info = newInfo, + initFlags = + if newInfo.isInstanceOf[LazyType] then sym.flags &~ Touched + else sym.flags + ).installAfter(prevPhase) + + /** Does symbol have a new denotation valid from phase.next that is different + * from the denotation it had before? + */ + def isUpdatedAfter(phase: Phase)(using Context) = + val symd = sym.denot + symd.validFor.firstPhaseId == phase.id + 1 && (sym.originDenotation ne symd) + + extension (tree: Tree) + + /** Remember `tpe` as the type of `tree`, which might be different from the + * type stored in the tree itself. + */ + def rememberType(tpe: Type)(using Context): Unit = + if (tpe ne tree.tpe) && !tree.hasAttachment(RecheckedType) then + tree.putAttachment(RecheckedType, tpe) + + /** The remembered type of the tree, or if none was installed, the original type */ + def knownType = + tree.attachmentOrElse(RecheckedType, tree.tpe) + + def hasRememberedType: Boolean = tree.hasAttachment(RecheckedType) + abstract class Recheck extends Phase, IdentityDenotTransformer: thisPhase => @@ -43,9 +82,7 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: override def widenSkolems = true def run(using Context): Unit = - val rechecker = newRechecker() - rechecker.transformTypes.traverse(ctx.compilationUnit.tpdTree) - rechecker.checkUnit(ctx.compilationUnit) + newRechecker().checkUnit(ctx.compilationUnit) def newRechecker()(using Context): Rechecker @@ -55,120 +92,6 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: ictx.settings.Xprint.value.containsPhase(thisPhase) } - extension (sym: Symbol) def updateInfo(newInfo: Type)(using Context): Unit = - if sym.info ne newInfo then - sym.copySymDenotation().installAfter(thisPhase) // reset - sym.copySymDenotation( - info = newInfo, - initFlags = - if newInfo.isInstanceOf[LazyType] then sym.flags &~ Touched - else sym.flags - ).installAfter(preRecheckPhase) - - extension (tpe: Type) def rememberFor(tree: Tree)(using Context): Unit = - if (tpe ne tree.tpe) && !tree.hasAttachment(RecheckedType) then - tree.putAttachment(RecheckedType, tpe) - - def knownType(tree: Tree) = - tree.attachmentOrElse(RecheckedType, tree.tpe) - - def isUpdated(sym: Symbol)(using Context) = - val symd = sym.denot - symd.validFor.firstPhaseId == thisPhase.id && (sym.originDenotation ne symd) - - def transformType(tp: Type, inferred: Boolean, boxed: Boolean = false)(using Context): Type = tp - - object transformTypes extends TreeTraverser: - - // Substitute parameter symbols in `from` to paramRefs in corresponding - // method or poly types `to`. We use a single BiTypeMap to do everything. - class SubstParams(from: List[List[Symbol]], to: List[LambdaType])(using Context) - extends DeepTypeMap, BiTypeMap: - - def apply(t: Type): Type = t match - case t: NamedType => - val sym = t.symbol - def outer(froms: List[List[Symbol]], tos: List[LambdaType]): Type = - def inner(from: List[Symbol], to: List[ParamRef]): Type = - if from.isEmpty then outer(froms.tail, tos.tail) - else if sym eq from.head then to.head - else inner(from.tail, to.tail) - if tos.isEmpty then t - else inner(froms.head, tos.head.paramRefs) - outer(from, to) - case _ => - mapOver(t) - - def inverse(t: Type): Type = t match - case t: ParamRef => - def recur(from: List[LambdaType], to: List[List[Symbol]]): Type = - if from.isEmpty then t - else if t.binder eq from.head then to.head(t.paramNum).namedType - else recur(from.tail, to.tail) - recur(to, from) - case _ => - mapOver(t) - end SubstParams - - private def transformTT(tree: TypeTree, boxed: Boolean)(using Context) = - transformType(tree.tpe, tree.isInstanceOf[InferredTypeTree], boxed).rememberFor(tree) - - def traverse(tree: Tree)(using Context) = - tree match - case tree @ ValDef(_, tpt: TypeTree, _) if tree.symbol.is(Mutable) => - transformTT(tpt, boxed = true) - traverse(tree.rhs) - case _ => - traverseChildren(tree) - tree match - case tree: TypeTree => - transformTT(tree, boxed = false) - case tree: ValOrDefDef => - val sym = tree.symbol - - // replace an existing symbol info with inferred types - def integrateRT( - info: Type, // symbol info to replace - psymss: List[List[Symbol]], // the local (type and trem) parameter symbols corresponding to `info` - prevPsymss: List[List[Symbol]], // the local parameter symbols seen previously in reverse order - prevLambdas: List[LambdaType] // the outer method and polytypes generated previously in reverse order - ): Type = - info match - case mt: MethodOrPoly => - val psyms = psymss.head - mt.companion(mt.paramNames)( - mt1 => - if !psyms.exists(isUpdated) && !mt.isParamDependent && prevLambdas.isEmpty then - mt.paramInfos - else - val subst = SubstParams(psyms :: prevPsymss, mt1 :: prevLambdas) - psyms.map(psym => subst(psym.info).asInstanceOf[mt.PInfo]), - mt1 => - integrateRT(mt.resType, psymss.tail, psyms :: prevPsymss, mt1 :: prevLambdas) - ) - case info: ExprType => - info.derivedExprType(resType = - integrateRT(info.resType, psymss, prevPsymss, prevLambdas)) - case _ => - val restp = knownType(tree.tpt) - if prevLambdas.isEmpty then restp - else SubstParams(prevPsymss, prevLambdas)(restp) - - if tree.tpt.hasAttachment(RecheckedType) && !sym.isConstructor then - val newInfo = integrateRT(sym.info, sym.paramSymss, Nil, Nil) - .showing(i"update info $sym: ${sym.info} --> $result", recheckr) - if newInfo ne sym.info then - val completer = new LazyType: - def complete(denot: SymDenotation)(using Context) = - denot.info = newInfo - recheckDef(tree, sym) - sym.updateInfo(completer) - case tree: Bind => - val sym = tree.symbol - sym.updateInfo(transformType(sym.info, inferred = true)) - case _ => - end transformTypes - def constFold(tree: Tree, tp: Type)(using Context): Type = val tree1 = tree.withType(tp) val tree2 = ConstFold(tree1) @@ -375,7 +298,7 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: case tree: ValOrDefDef => if tree.isEmpty then NoType else - if isUpdated(sym) then sym.ensureCompleted() + if sym.isUpdatedAfter(preRecheckPhase) then sym.ensureCompleted() else recheckDef(tree, sym) sym.termRef case tree: TypeDef => @@ -414,7 +337,7 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: def recheckFinish(tpe: Type, tree: Tree, pt: Type)(using Context): Type = checkConforms(tpe, pt, tree) - if keepTypes then tpe.rememberFor(tree) + if keepTypes then tree.rememberType(tpe) tpe def recheck(tree: Tree, pt: Type = WildcardType)(using Context): Type = diff --git a/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala index 3cdf39cb56e6..bfe93514548e 100644 --- a/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala +++ b/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala @@ -5,26 +5,17 @@ package cc import core._ import Phases.*, DenotTransformers.*, SymDenotations.* import Contexts.*, Names.*, Flags.*, Symbols.*, Decorators.* -import Types._ -import Symbols._ -import StdNames._ -import Decorators._ +import Types.*, StdNames.* import config.Printers.{capt, recheckr} import ast.{tpd, untpd, Trees} -import NameKinds.{DocArtifactName, OuterSelectName, DefaultGetterName} -import Trees._ -import scala.util.control.NonFatal -import typer.ErrorReporting._ -import typer.RefChecks -import util.Spans.Span +import Trees.* +import typer.RefChecks.checkAllOverrides import util.{SimpleIdentitySet, EqHashMap, SrcPos} -import util.Chars.* -import transform.* import transform.SymUtils.* +import transform.Recheck +import Recheck.* import scala.collection.mutable -import reporting._ -import dotty.tools.backend.jvm.DottyBackendInterface.symExtensions -import CaptureSet.{CompareResult, withCaptureSetsExplained} +import CaptureSet.withCaptureSetsExplained object CheckCaptures: import ast.tpd.* @@ -108,180 +99,13 @@ class CheckCaptures extends Recheck: // ^^^ TODO: Can we avoid doing overrides checks twice? // We need to do them here since only at this phase CaptureTypes are relevant // But maybe we can then elide the check during the RefChecks phase if -Ycc is set? - RefChecks.checkAllOverrides(ctx.owner.asClass) + checkAllOverrides(ctx.owner.asClass) case _ => traverseChildren(t) class CaptureChecker(ictx: Context) extends Rechecker(ictx): import ast.tpd.* - override def transformType(tp: Type, inferred: Boolean, boxed: Boolean)(using Context): Type = - - def depFun(tycon: Type, argTypes: List[Type], resType: Type): Type = - MethodType.companion( - isContextual = defn.isContextFunctionClass(tycon.classSymbol), - isErased = defn.isErasedFunctionClass(tycon.classSymbol) - )(argTypes, resType) - .toFunctionType(isJava = false, alwaysDependent = true) - - def box(tp: Type): Type = tp match - case CapturingType(parent, refs, false) => CapturingType(parent, refs, true) - case _ => tp - - /** Perform the following transformation steps everywhere in a type: - * 1. Drop retains annotations - * 2. Turn plain function types into dependent function types, so that - * we can refer to their parameter in capture sets. Currently this is - * only done at the toplevel, i.e. for function types that are not - * themselves argument types of other function types. Without this restriction - * boxmap-paper.scala fails. Need to figure out why. - * 3. Refine other class types C by adding capture set variables to their parameter getters - * (see addCaptureRefinements) - * 4. Add capture set variables to all types that can be tracked - * - * Polytype bounds are only cleaned using step 1, but not otherwise transformed. - */ - def mapInferred = new TypeMap: - - /** Drop @retains annotations everywhere */ - object cleanup extends TypeMap: - def apply(t: Type) = t match - case AnnotatedType(parent, annot) if annot.symbol == defn.RetainsAnnot => - apply(parent) - case _ => - mapOver(t) - - /** Refine a possibly applied class type C where the class has tracked parameters - * x_1: T_1, ..., x_n: T_n to C { val x_1: CV_1 T_1, ..., val x_n: CV_n T_n } - * where CV_1, ..., CV_n are fresh capture sets. - */ - def addCaptureRefinements(tp: Type): Type = tp match - case _: TypeRef | _: AppliedType if tp.typeParams.isEmpty => - tp.typeSymbol match - case cls: ClassSymbol if !defn.isFunctionClass(cls) => - cls.paramGetters.foldLeft(tp) { (core, getter) => - if getter.termRef.isTracked then - val getterType = tp.memberInfo(getter).strippedDealias - RefinedType(core, getter.name, CapturingType(getterType, CaptureSet.Var(), boxed = false)) - .showing(i"add capture refinement $tp --> $result", capt) - else - core - } - case _ => tp - case _ => tp - - /** Should a capture set variable be added on type `tp`? */ - def canHaveInferredCapture(tp: Type): Boolean = - tp.typeParams.isEmpty && tp.match - case tp: (TypeRef | AppliedType) => - val sym = tp.typeSymbol - if sym.isClass then !sym.isValueClass && sym != defn.AnyClass - else canHaveInferredCapture(tp.superType.dealias) - case tp: (RefinedOrRecType | MatchType) => - canHaveInferredCapture(tp.underlying) - case tp: AndType => - canHaveInferredCapture(tp.tp1) && canHaveInferredCapture(tp.tp2) - case tp: OrType => - canHaveInferredCapture(tp.tp1) || canHaveInferredCapture(tp.tp2) - case _ => - false - - /** Add a capture set variable to `tp` if necessary, or maybe pull out - * an embedded capture set variables from a part of `tp`. - */ - def addVar(tp: Type) = tp match - case tp @ RefinedType(parent @ CapturingType(parent1, refs, boxed), rname, rinfo) => - CapturingType(tp.derivedRefinedType(parent1, rname, rinfo), refs, boxed) - case tp: RecType => - tp.parent match - case CapturingType(parent1, refs, boxed) => - CapturingType(tp.derivedRecType(parent1), refs, boxed) - case _ => - tp // can return `tp` here since unlike RefinedTypes, RecTypes are never created - // by `mapInferred`. Hence if the underlying type admits capture variables - // a variable was already added, and the first case above would apply. - case AndType(CapturingType(parent1, refs1, boxed1), CapturingType(parent2, refs2, boxed2)) => - assert(refs1.asVar.elems.isEmpty) - assert(refs2.asVar.elems.isEmpty) - assert(boxed1 == boxed2) - CapturingType(AndType(parent1, parent2), refs1, boxed1) - case tp @ OrType(CapturingType(parent1, refs1, boxed1), CapturingType(parent2, refs2, boxed2)) => - assert(refs1.asVar.elems.isEmpty) - assert(refs2.asVar.elems.isEmpty) - assert(boxed1 == boxed2) - CapturingType(OrType(parent1, parent2, tp.isSoft), refs1, boxed1) - case tp @ OrType(CapturingType(parent1, refs1, boxed1), tp2) => - CapturingType(OrType(parent1, tp2, tp.isSoft), refs1, boxed1) - case tp @ OrType(tp1, CapturingType(parent2, refs2, boxed2)) => - CapturingType(OrType(tp1, parent2, tp.isSoft), refs2, boxed2) - case _ if canHaveInferredCapture(tp) => - CapturingType(tp, CaptureSet.Var(), boxed = false) - case _ => - tp - - var isTopLevel = true - - def mapNested(ts: List[Type]): List[Type] = - val saved = isTopLevel - isTopLevel = false - try ts.mapConserve(this) finally isTopLevel = saved - - def apply(t: Type) = - val t1 = t match - case AnnotatedType(parent, annot) if annot.symbol == defn.RetainsAnnot => - apply(parent) - case tp @ AppliedType(tycon, args) => - val tycon1 = this(tycon) - if defn.isNonRefinedFunction(tp) then - val args1 = mapNested(args.init) - val res1 = this(args.last) - if isTopLevel then - depFun(tycon1, args1, res1) - .showing(i"add function refinement $tp --> $result", capt) - else - tp.derivedAppliedType(tycon1, args1 :+ res1) - else - tp.derivedAppliedType(tycon1, args.mapConserve(arg => box(this(arg)))) - case tp @ RefinedType(core, rname, rinfo) if defn.isFunctionType(tp) => - apply(rinfo).toFunctionType(isJava = false, alwaysDependent = true) - case tp: MethodType => - tp.derivedLambdaType( - paramInfos = mapNested(tp.paramInfos), - resType = this(tp.resType)) - case tp: TypeLambda => - // Don't recurse into parameter bounds, just cleanup any stray retains annotations - tp.derivedLambdaType( - paramInfos = tp.paramInfos.mapConserve(cleanup(_).bounds), - resType = this(tp.resType)) - case _ => - mapOver(t) - addVar(addCaptureRefinements(t1)) - end mapInferred - - if inferred then - val tp1 = mapInferred(tp) - if boxed then box(tp1) else tp1 - else - def setBoxed(t: Type) = t match - case AnnotatedType(_, annot) if annot.symbol == defn.RetainsAnnot => - annot.tree.setBoxedCapturing() - case _ => - - val addBoxes = new TypeTraverser: - def traverse(t: Type) = - t match - case AppliedType(tycon, args) if !defn.isNonRefinedFunction(t) => - args.foreach(setBoxed) - case TypeBounds(lo, hi) => - setBoxed(lo); setBoxed(hi) - case _ => - traverseChildren(t) - - if boxed then setBoxed(tp) - addBoxes.traverse(tp) - tp - end transformType - private def interpolator(using Context) = new TypeTraverser: override def traverse(t: Type) = t match @@ -299,8 +123,8 @@ class CheckCaptures extends Recheck: private def interpolateVarsIn(tpt: Tree)(using Context): Unit = if tpt.isInstanceOf[InferredTypeTree] then - interpolator.traverse(knownType(tpt)) - .showing(i"solved vars in ${knownType(tpt)}", capt) + interpolator.traverse(tpt.knownType) + .showing(i"solved vars in ${tpt.knownType}", capt) private var curEnv: Env = Env(NoSymbol, CaptureSet.empty, false, null) @@ -477,6 +301,8 @@ class CheckCaptures extends Recheck: res override def checkUnit(unit: CompilationUnit)(using Context): Unit = + Setup(preRecheckPhase, thisPhase, recheckDef) + .traverse(ctx.compilationUnit.tpdTree) withCaptureSetsExplained { super.checkUnit(unit) PostRefinerCheck.traverse(unit.tpdTree) @@ -494,12 +320,12 @@ class CheckCaptures extends Recheck: val notAllowed = i" is not allowed to capture the $what capability $ref" def msg = if allArgs.isEmpty then - i"type of mutable variable ${knownType(tree)}$notAllowed" + i"type of mutable variable ${tree.knownType}$notAllowed" else tree match case tree: InferredTypeTree => - i"""inferred type argument ${knownType(tree)}$notAllowed + i"""inferred type argument ${tree.knownType}$notAllowed | - |The inferred arguments are: [${allArgs.map(knownType)}%, %]""" + |The inferred arguments are: [${allArgs.map(_.knownType)}%, %]""" case _ => s"type argument$notAllowed" report.error(msg, tree.srcPos) @@ -509,7 +335,7 @@ class CheckCaptures extends Recheck: case LambdaTypeTree(_, restpt) => checkNotGlobal(restpt, allArgs*) case _ => - checkNotGlobal(tree, knownType(tree), allArgs*) + checkNotGlobal(tree, tree.knownType, allArgs*) def checkNotGlobalDeep(tree: Tree)(using Context): Unit = val checker = new TypeTraverser: @@ -522,23 +348,23 @@ class CheckCaptures extends Recheck: case _ => checkNotGlobal(tree, tp) traverseChildren(tp) - checker.traverse(knownType(tree)) + checker.traverse(tree.knownType) object PostRefinerCheck extends TreeTraverser: def traverse(tree: Tree)(using Context) = tree match case _: InferredTypeTree => case tree: TypeTree if !tree.span.isZeroExtent => - knownType(tree).foreachPart( + tree.knownType.foreachPart( checkWellformedPost(_, tree.srcPos)) - knownType(tree).foreachPart { + tree.knownType.foreachPart { case AnnotatedType(_, annot) => checkWellformedPost(annot.tree) case _ => } case tree1 @ TypeApply(fn, args) if disallowGlobal => for arg <- args do - //println(i"checking $arg in $tree: ${knownType(tree).captureSet}") + //println(i"checking $arg in $tree: ${tree.knownType.captureSet}") checkNotGlobal(arg, args*) case t: ValOrDefDef if t.tpt.isInstanceOf[InferredTypeTree] => val sym = t.symbol @@ -551,7 +377,7 @@ class CheckCaptures extends Recheck: || sym.owner.is(Trait) // ... since we do OverridingPairs checking before capture inference || sym.allOverriddenSymbols.nonEmpty // ... since we do override checking before capture inference then - val inferred = knownType(t.tpt) + val inferred = t.tpt.knownType def checkPure(tp: Type) = tp match case CapturingType(_, refs, _) if !refs.elems.isEmpty => val resultStr = if t.isInstanceOf[DefDef] then " result" else "" diff --git a/tests/neg-custom-args/captures/curried-simplified.check b/tests/neg-custom-args/captures/curried-simplified.check new file mode 100644 index 000000000000..055558530a76 --- /dev/null +++ b/tests/neg-custom-args/captures/curried-simplified.check @@ -0,0 +1,42 @@ +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/curried-simplified.scala:7:28 ---------------------------- +7 | def y1: () -> () -> Int = x1 // error + | ^^ + | Found: {x} () -> {x} () -> Int + | Required: () -> () -> Int + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/curried-simplified.scala:9:28 ---------------------------- +9 | def y2: () -> () => Int = x2 // error + | ^^ + | Found: {x} () -> () => Int + | Required: () -> () => Int + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/curried-simplified.scala:11:39 --------------------------- +11 | def y3: Cap -> Protect[Int -> Int] = x3 // error + | ^^ + | Found: (x$0: Cap) -> {x$0} Int -> Int + | Required: Cap -> Protect[Int -> Int] + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/curried-simplified.scala:15:33 --------------------------- +15 | def y5: Cap -> {} Int -> Int = x5 // error + | ^^ + | Found: Cap -> {x} Int -> Int + | Required: Cap -> {} Int -> Int + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/curried-simplified.scala:17:49 --------------------------- +17 | def y6: Cap -> {} Cap -> Protect[Int -> Int] = x6 // error + | ^^ + | Found: (x$0: Cap) -> {x$0} (x$0: Cap) -> {x$0, x$0} Int -> Int + | Required: Cap -> {} Cap -> Protect[Int -> Int] + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/curried-simplified.scala:19:49 --------------------------- +19 | def y7: Cap -> Protect[Cap -> {} Int -> Int] = x7 // error + | ^^ + | Found: (x$0: Cap) -> {x$0} (x: Cap) -> {x$0, x} Int -> Int + | Required: Cap -> Protect[Cap -> {} Int -> Int] + +longer explanation available when compiling with `-explain` diff --git a/tests/neg-custom-args/captures/curried-simplified.scala b/tests/neg-custom-args/captures/curried-simplified.scala new file mode 100644 index 000000000000..25b23370d154 --- /dev/null +++ b/tests/neg-custom-args/captures/curried-simplified.scala @@ -0,0 +1,21 @@ +@annotation.capability class Cap + +type Protect[T] = T + +def test(x: Cap, y: Cap) = + def x1: {x} () -> () -> Int = ??? + def y1: () -> () -> Int = x1 // error + def x2: {x} () -> () => Int = ??? + def y2: () -> () => Int = x2 // error + def x3: Cap -> Int -> Int = ??? + def y3: Cap -> Protect[Int -> Int] = x3 // error + def x4: Cap -> Protect[Int -> Int] = ??? + def y4: Cap -> {} Int -> Int = x4 // ok + def x5: Cap -> {x} Int -> Int = ??? + def y5: Cap -> {} Int -> Int = x5 // error + def x6: Cap -> Cap -> Int -> Int = ??? + def y6: Cap -> {} Cap -> Protect[Int -> Int] = x6 // error + def x7: Cap -> (x: Cap) -> Int -> Int = ??? + def y7: Cap -> Protect[Cap -> {} Int -> Int] = x7 // error + + diff --git a/tests/pos-custom-args/captures/capt-depfun.scala b/tests/pos-custom-args/captures/capt-depfun.scala index 072eaefd3e78..808b1b5e85f3 100644 --- a/tests/pos-custom-args/captures/capt-depfun.scala +++ b/tests/pos-custom-args/captures/capt-depfun.scala @@ -3,6 +3,8 @@ type Cap = C @retains(*) type T = (x: Cap) -> String @retains(x) +type ID[X] = X + val aa: ((x: Cap) -> String @retains(x)) = (x: Cap) => "" def f(y: Cap, z: Cap): String @retains(*) = @@ -12,7 +14,7 @@ def f(y: Cap, z: Cap): String @retains(*) = def g(): C @retains(y, z) = ??? val d = a(g()) - val ac: ((x: Cap) -> String @retains(x) -> String @retains(x)) = ??? + val ac: ((x: Cap) -> ID[String @retains(x) -> String @retains(x)]) = ??? val bc: (({y} String) -> {y} String) = ac(y) val dc: (String -> {y, z} String) = ac(g()) c From 670e5103b16e3bb0ed9edcf52e4f047f507381ad Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Wed, 26 Jan 2022 16:26:19 +0100 Subject: [PATCH 15/24] Syntax change: allow capture sets in infix types --- .../src/dotty/tools/dotc/parsing/Parsers.scala | 18 +++++++++++------- tests/disabled/pos/lazylist.scala | 2 +- tests/pos-custom-args/captures/classes.scala | 4 ++-- tests/pos-custom-args/captures/iterators.scala | 2 +- tests/pos-custom-args/captures/lazylists.scala | 8 ++++---- .../pos-custom-args/captures/lazylists1.scala | 4 ++-- 6 files changed, 21 insertions(+), 17 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala index b5910a7dac7b..952a6022428e 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala @@ -1486,11 +1486,7 @@ object Parsers { else { accept(TLARROW); typ() } } else if in.token == LBRACE && followingIsCaptureSet() then - val refs = inBraces { - if in.token == RBRACE then Nil else commaSeparated(captureRef) - } - val t = typ() - CapturingTypeTree(refs, t) + CapturingTypeTree(captureSet(), typ()) else if (in.token == INDENT) enclosed(INDENT, typ()) else infixType() @@ -1873,8 +1869,14 @@ object Parsers { def typeDependingOn(location: Location): Tree = if location.inParens then typ() else if location.inPattern then rejectWildcardType(refinedType()) + else if in.token == LBRACE && followingIsCaptureSet() then + CapturingTypeTree(captureSet(), infixType()) else infixType() + def captureSet(): List[Tree] = inBraces { + if in.token == RBRACE then Nil else commaSeparated(captureRef) + } + /* ----------- EXPRESSIONS ------------------------------------------------ */ /** Does the current conditional expression continue after @@ -1944,7 +1946,7 @@ object Parsers { * | ‘inline’ InfixExpr MatchClause * Bindings ::= `(' [Binding {`,' Binding}] `)' * Binding ::= (id | `_') [`:' Type] - * Ascription ::= `:' InfixType + * Ascription ::= `:' [CaptureSet] InfixType * | `:' Annotation {Annotation} * | `:' `_' `*' * Catches ::= ‘catch’ (Expr | ExprCaseClause) @@ -3900,7 +3902,7 @@ object Parsers { stats.toList } - /** TemplateStatSeq ::= [id [`:' Type] `=>'] TemplateStat {semi TemplateStat} + /** TemplateStatSeq ::= [SelfType] TemplateStat {semi TemplateStat} * TemplateStat ::= Import * | Export * | Annotations Modifiers Def @@ -3910,6 +3912,8 @@ object Parsers { * | * EnumStat ::= TemplateStat * | Annotations Modifiers EnumCase + * SelfType ::= id [‘:’ [CaptureSet] InfixType] ‘=>’ + * | ‘this’ ‘:’ [CaptureSet] InfixType ‘=>’ */ def templateStatSeq(): (ValDef, List[Tree]) = checkNoEscapingPlaceholders { var self: ValDef = EmptyValDef diff --git a/tests/disabled/pos/lazylist.scala b/tests/disabled/pos/lazylist.scala index 958f4c35aaf0..c24f8677b91f 100644 --- a/tests/disabled/pos/lazylist.scala +++ b/tests/disabled/pos/lazylist.scala @@ -1,7 +1,7 @@ package lazylists abstract class LazyList[+T]: - this: ({*} LazyList[T]) => + this: {*} LazyList[T] => def isEmpty: Boolean def head: T diff --git a/tests/pos-custom-args/captures/classes.scala b/tests/pos-custom-args/captures/classes.scala index f3d6e44b27ca..243f70e02899 100644 --- a/tests/pos-custom-args/captures/classes.scala +++ b/tests/pos-custom-args/captures/classes.scala @@ -1,7 +1,7 @@ class B type Cap = {*} B class C(val n: Cap): - this: ({n} C) => + this: {n} C => def foo(): {n} B = n @@ -9,7 +9,7 @@ def test(x: Cap, y: Cap, z: Cap) = val c0 = C(x) val c1: {x} C {val n: {x} B} = c0 val d = c1.foo() - d: ({x} B) + d: {x} B val c2 = if ??? then C(x) else C(y) val c2a = identity(c2) diff --git a/tests/pos-custom-args/captures/iterators.scala b/tests/pos-custom-args/captures/iterators.scala index 1ac1bd96f6d7..50be2012e25c 100644 --- a/tests/pos-custom-args/captures/iterators.scala +++ b/tests/pos-custom-args/captures/iterators.scala @@ -1,7 +1,7 @@ package cctest abstract class Iterator[T]: - thisIterator: ({*} Iterator[T]) => + thisIterator: {*} Iterator[T] => def hasNext: Boolean def next: T diff --git a/tests/pos-custom-args/captures/lazylists.scala b/tests/pos-custom-args/captures/lazylists.scala index bf3e0300b5b5..c566bea8dd64 100644 --- a/tests/pos-custom-args/captures/lazylists.scala +++ b/tests/pos-custom-args/captures/lazylists.scala @@ -2,7 +2,7 @@ class CC type Cap = {*} CC trait LazyList[+A]: - this: ({*} LazyList[A]) => + this: {*} LazyList[A] => def isEmpty: Boolean def head: A @@ -16,12 +16,12 @@ object LazyNil extends LazyList[Nothing]: extension [A](xs: {*} LazyList[A]) def map[B](f: A => B): {xs, f} LazyList[B] = final class Mapped extends LazyList[B]: - this: ({xs, f} Mapped) => + this: {xs, f} Mapped => def isEmpty = false def head: B = f(xs.head) def tail: {this} LazyList[B] = xs.tail.map(f) // OK - def concat(other: {f} LazyList[A]): {this, f} LazyList[A] = ??? : ({xs, f} LazyList[A]) // OK + def concat(other: {f} LazyList[A]): {this, f} LazyList[A] = ??? : {xs, f} LazyList[A] // OK if xs.isEmpty then LazyNil else new Mapped @@ -31,7 +31,7 @@ def test(cap1: Cap, cap2: Cap) = val xs = class Initial extends LazyList[String]: - this: ({cap1} Initial) => + this: {cap1} Initial => def isEmpty = false def head = f("") diff --git a/tests/pos-custom-args/captures/lazylists1.scala b/tests/pos-custom-args/captures/lazylists1.scala index 7fbdde87ad9b..2dbb5ac232e2 100644 --- a/tests/pos-custom-args/captures/lazylists1.scala +++ b/tests/pos-custom-args/captures/lazylists1.scala @@ -2,7 +2,7 @@ class CC type Cap = {*} CC trait LazyList[+A]: - this: ({*} LazyList[A]) => + this: {*} LazyList[A] => def isEmpty: Boolean def head: A @@ -14,7 +14,7 @@ object LazyNil extends LazyList[Nothing]: def tail = ??? final class LazyCons[+T](val x: T, val xs: Int => {*} LazyList[T]) extends LazyList[T]: - this: ({*} LazyList[T]) => + this: {*} LazyList[T] => def isEmpty = false def head = x From 7ba7c898df4ac503c7a308df8dde933caf42e903 Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Wed, 26 Jan 2022 16:26:44 +0100 Subject: [PATCH 16/24] Test for contravariantly used class fields As discussed in the CC meeting on 21 Jan 2022 --- .../dotty/tools/dotc/core/tasty/TreePickler.scala | 1 + tests/neg-custom-args/captures/class-contra.check | 7 +++++++ tests/neg-custom-args/captures/class-contra.scala | 13 +++++++++++++ 3 files changed, 21 insertions(+) create mode 100644 tests/neg-custom-args/captures/class-contra.check create mode 100644 tests/neg-custom-args/captures/class-contra.scala diff --git a/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala b/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala index c83d65ff483d..9a659340cd52 100644 --- a/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala +++ b/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala @@ -203,6 +203,7 @@ class TreePickler(pickler: TastyPickler) { writeByte(if (tpe.isType) TYPEREFdirect else TERMREFdirect) if !symRefs.contains(sym) && !sym.isPatternBound && !sym.hasAnnotation(defn.QuotedRuntimePatterns_patternTypeAnnot) then report.log(i"pickling reference to as yet undefined $tpe with symbol ${sym}", sym.srcPos) + // todo: find out why this happens for pos-customargs/captures/capt2 pickleSymRef(sym) } else tpe.designator match { diff --git a/tests/neg-custom-args/captures/class-contra.check b/tests/neg-custom-args/captures/class-contra.check new file mode 100644 index 000000000000..3825d57b602e --- /dev/null +++ b/tests/neg-custom-args/captures/class-contra.check @@ -0,0 +1,7 @@ +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/class-contra.scala:12:39 --------------------------------- +12 | def fun(x: K{val f: {a} T}) = x.setf(a) // error + | ^ + | Found: (a : {x, y} T) + | Required: T + +longer explanation available when compiling with `-explain` diff --git a/tests/neg-custom-args/captures/class-contra.scala b/tests/neg-custom-args/captures/class-contra.scala new file mode 100644 index 000000000000..270aaf9309a9 --- /dev/null +++ b/tests/neg-custom-args/captures/class-contra.scala @@ -0,0 +1,13 @@ + +class C +type Cap = {*} C + +class K(val f: {*} T): + def setf(x: {f} T) = ??? + +class T + +def test(x: Cap, y: Cap) = + val a: {x, y} T = ??? + def fun(x: K{val f: {a} T}) = x.setf(a) // error + () \ No newline at end of file From 42c9eedff1fa1bf323d740514b36df8c33dc4826 Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Sun, 23 Jan 2022 16:21:24 +0100 Subject: [PATCH 17/24] Handle captures in by-name parameters 1. Infrastructure to deal with capturesets in byname parameters 2. Handle retainsByName annotations in ElimByName Convert them to regular annotations on the generated function types. This enables capture checking on by-name parameters. 3. Add a style warning for misleading by-name parameter type formatting. By-name types should be formatted `{...}-> T`. `{...} -> T` looks too much like a function type. --- .../src/dotty/tools/dotc/ast/Desugar.scala | 14 ++-- .../src/dotty/tools/dotc/ast/TreeInfo.scala | 23 ++++++- .../tools/dotc/cc/CaptureAnnotation.scala | 20 +++--- .../src/dotty/tools/dotc/cc/CaptureOps.scala | 7 +- .../src/dotty/tools/dotc/cc/CaptureSet.scala | 5 +- .../dotty/tools/dotc/cc/CapturingKind.scala | 9 +++ .../dotty/tools/dotc/cc/CapturingType.scala | 38 +++++++--- compiler/src/dotty/tools/dotc/cc/Setup.scala | 7 +- .../dotty/tools/dotc/core/Definitions.scala | 20 ++++-- .../src/dotty/tools/dotc/core/StdNames.scala | 1 + .../dotty/tools/dotc/core/TypeComparer.scala | 4 +- .../src/dotty/tools/dotc/core/Types.scala | 15 ++-- .../dotty/tools/dotc/parsing/Parsers.scala | 69 ++++++++++++++----- .../tools/dotc/printing/PlainPrinter.scala | 15 ++-- .../tools/dotc/printing/RefinedPrinter.scala | 9 ++- .../dotty/tools/dotc/reporting/messages.scala | 2 +- .../tools/dotc/typer/ErrorReporting.scala | 8 ++- .../src/dotty/tools/dotc/typer/Typer.scala | 3 +- .../scala/retainsByName.scala | 6 ++ tests/neg-custom-args/captures/byname.check | 20 ++++++ tests/neg-custom-args/captures/byname.scala | 10 +++ tests/neg-custom-args/captures/lazylist.check | 4 +- .../neg-custom-args/captures/lazylists2.check | 4 +- tests/pos-custom-args/captures/byname.scala | 4 +- 24 files changed, 233 insertions(+), 84 deletions(-) create mode 100644 compiler/src/dotty/tools/dotc/cc/CapturingKind.scala create mode 100644 library/src-bootstrapped/scala/retainsByName.scala create mode 100644 tests/neg-custom-args/captures/byname.check diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 38030955b776..f6eae20206ba 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -454,7 +454,7 @@ object desugar { if mods.is(Trait) then for vparams <- originalVparamss; vparam <- vparams do - if vparam.tpt.isInstanceOf[ByNameTypeTree] then + if isByNameType(vparam.tpt) then report.error(em"implementation restriction: traits cannot have by name parameters", vparam.srcPos) // Annotations on class _type_ parameters are set on the derived parameters @@ -558,9 +558,8 @@ object desugar { appliedTypeTree(tycon, targs) } - def isRepeated(tree: Tree): Boolean = tree match { + def isRepeated(tree: Tree): Boolean = stripByNameType(tree) match { case PostfixOp(_, Ident(tpnme.raw.STAR)) => true - case ByNameTypeTree(tree1) => isRepeated(tree1) case _ => false } @@ -1734,8 +1733,13 @@ object desugar { case ext: ExtMethods => Block(List(ext), Literal(Constant(())).withSpan(ext.span)) case CapturingTypeTree(refs, parent) => - val annot = New(scalaDot(tpnme.retains), List(refs)) - Annotated(parent, annot) + def annotate(annotName: TypeName, tp: Tree) = + Annotated(tp, New(scalaDot(annotName), List(refs))) + parent match + case ByNameTypeTree(restpt) => + cpy.ByNameTypeTree(parent)(annotate(tpnme.retainsByName, restpt)) + case _ => + annotate(tpnme.retains, parent) } desugared.withSpan(tree.span) } diff --git a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala index 8592b6d2e647..09dc2efec965 100644 --- a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala +++ b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala @@ -172,8 +172,7 @@ trait TreeInfo[T >: Untyped <: Type] { self: Trees.Instance[T] => } /** Is tpt a vararg type of the form T* or => T*? */ - def isRepeatedParamType(tpt: Tree)(using Context): Boolean = tpt match { - case ByNameTypeTree(tpt1) => isRepeatedParamType(tpt1) + def isRepeatedParamType(tpt: Tree)(using Context): Boolean = stripByNameType(tpt) match { case tpt: TypeTree => tpt.typeOpt.isRepeatedParam case AppliedTypeTree(Select(_, tpnme.REPEATED_PARAM_CLASS), _) => true case _ => false @@ -190,6 +189,16 @@ trait TreeInfo[T >: Untyped <: Type] { self: Trees.Instance[T] => case arg => arg.typeOpt.widen.isRepeatedParam } + def isByNameType(tree: Tree)(using Context): Boolean = + stripByNameType(tree) ne tree + + def stripByNameType(tree: Tree)(using Context): Tree = unsplice(tree) match + case ByNameTypeTree(t1) => t1 + case untpd.CapturingTypeTree(_, parent) => + val parent1 = stripByNameType(parent) + if parent1 eq parent then tree else parent1 + case _ => tree + /** All type and value parameter symbols of this DefDef */ def allParamSyms(ddef: DefDef)(using Context): List[Symbol] = ddef.paramss.flatten.map(_.symbol) @@ -389,6 +398,16 @@ trait UntypedTreeInfo extends TreeInfo[Untyped] { self: Trees.Instance[Untyped] case _ => None } } + + object ImpureByNameTypeTree: + def apply(tp: ByNameTypeTree)(using Context): untpd.CapturingTypeTree = + untpd.CapturingTypeTree( + Ident(nme.CAPTURE_ROOT).withSpan(tp.span.startPos) :: Nil, tp) + def unapply(tp: Tree)(using Context): Option[ByNameTypeTree] = tp match + case untpd.CapturingTypeTree(id @ Ident(nme.CAPTURE_ROOT) :: Nil, bntp: ByNameTypeTree) + if id.span == bntp.span.startPos => Some(bntp) + case _ => None + end ImpureByNameTypeTree } trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] => diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureAnnotation.scala b/compiler/src/dotty/tools/dotc/cc/CaptureAnnotation.scala index 5f73b50a6bbe..9f4f99ad52f3 100644 --- a/compiler/src/dotty/tools/dotc/cc/CaptureAnnotation.scala +++ b/compiler/src/dotty/tools/dotc/cc/CaptureAnnotation.scala @@ -12,7 +12,7 @@ import printing.Printer import printing.Texts.Text -case class CaptureAnnotation(refs: CaptureSet, boxed: Boolean) extends Annotation: +case class CaptureAnnotation(refs: CaptureSet, kind: CapturingKind) extends Annotation: import CaptureAnnotation.* import tpd.* @@ -25,17 +25,18 @@ case class CaptureAnnotation(refs: CaptureSet, boxed: Boolean) extends Annotatio val arg = repeated(elems, TypeTree(defn.AnyType)) New(symbol.typeRef, arg :: Nil) - override def symbol(using Context) = defn.RetainsAnnot + override def symbol(using Context) = + if kind == CapturingKind.ByName then defn.RetainsByNameAnnot else defn.RetainsAnnot override def derivedAnnotation(tree: Tree)(using Context): Annotation = unsupported("derivedAnnotation(Tree)") - def derivedAnnotation(refs: CaptureSet, boxed: Boolean)(using Context): Annotation = - if (this.refs eq refs) && (this.boxed == boxed) then this - else CaptureAnnotation(refs, boxed) + def derivedAnnotation(refs: CaptureSet, kind: CapturingKind)(using Context): Annotation = + if (this.refs eq refs) && (this.kind == kind) then this + else CaptureAnnotation(refs, kind) override def sameAnnotation(that: Annotation)(using Context): Boolean = that match - case CaptureAnnotation(refs2, boxed2) => refs == refs2 && boxed == boxed2 + case CaptureAnnotation(refs2, kind2) => refs == refs2 && kind == kind2 case _ => false override def mapWith(tp: TypeMap)(using Context) = @@ -43,7 +44,7 @@ case class CaptureAnnotation(refs: CaptureSet, boxed: Boolean) extends Annotatio val elems1 = elems.mapConserve(tp) if elems1 eq elems then this else if elems1.forall(_.isInstanceOf[CaptureRef]) - then derivedAnnotation(CaptureSet(elems1.asInstanceOf[List[CaptureRef]]*), boxed) + then derivedAnnotation(CaptureSet(elems1.asInstanceOf[List[CaptureRef]]*), kind) else EmptyAnnotation override def refersToParamOf(tl: TermLambda)(using Context): Boolean = @@ -54,10 +55,11 @@ case class CaptureAnnotation(refs: CaptureSet, boxed: Boolean) extends Annotatio override def toText(printer: Printer): Text = refs.toText(printer) - override def hash: Int = (refs.hashCode << 1) | (if boxed then 1 else 0) + override def hash: Int = + (refs.hashCode << 1) | (if kind == CapturingKind.Regular then 0 else 1) override def eql(that: Annotation) = that match - case that: CaptureAnnotation => (this.refs eq that.refs) && (this.boxed == boxed) + case that: CaptureAnnotation => (this.refs eq that.refs) && (this.kind == kind) case _ => false end CaptureAnnotation diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala index 4c201d7edf54..117b6e528e62 100644 --- a/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala +++ b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala @@ -43,9 +43,9 @@ extension (tree: Tree) extension (tp: Type) def derivedCapturingType(parent: Type, refs: CaptureSet)(using Context): Type = tp match - case CapturingType(p, r, b) => + case CapturingType(p, r, k) => if (parent eq p) && (refs eq r) then tp - else CapturingType(parent, refs, b) + else CapturingType(parent, refs, k) /** If this is type variable instantiated or upper bounded with a capturing type, * the capture set associated with that type. Extended to and-or types and @@ -54,7 +54,8 @@ extension (tp: Type) */ def boxedCaptured(using Context): CaptureSet = def getBoxed(tp: Type): CaptureSet = tp match - case CapturingType(_, refs, boxed) => if boxed then refs else CaptureSet.empty + case CapturingType(_, refs, CapturingKind.Boxed) => refs + case CapturingType(_, _, _) => CaptureSet.empty case tp: TypeProxy => getBoxed(tp.superType) case tp: AndType => getBoxed(tp.tp1) ++ getBoxed(tp.tp2) case tp: OrType => getBoxed(tp.tp1) ** getBoxed(tp.tp2) diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala b/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala index 82e5e6e14a4b..6118e54174cd 100644 --- a/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala +++ b/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala @@ -209,8 +209,9 @@ sealed abstract class CaptureSet extends Showable: ((NoType: Type) /: elems) ((tp, ref) => if tp.exists then OrType(tp, ref, soft = false) else ref) - def toRegularAnnotation(using Context): Annotation = - Annotation(CaptureAnnotation(this, boxed = false).tree) + def toRegularAnnotation(byName: Boolean)(using Context): Annotation = + val kind = if byName then CapturingKind.ByName else CapturingKind.Regular + Annotation(CaptureAnnotation(this, kind).tree) override def toText(printer: Printer): Text = Str("{") ~ Text(elems.toList.map(printer.toTextCaptureRef), ", ") ~ Str("}") diff --git a/compiler/src/dotty/tools/dotc/cc/CapturingKind.scala b/compiler/src/dotty/tools/dotc/cc/CapturingKind.scala new file mode 100644 index 000000000000..3bb00b110b21 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/cc/CapturingKind.scala @@ -0,0 +1,9 @@ +package dotty.tools +package dotc +package cc + +/** Possible kinds of captures */ +enum CapturingKind: + case Regular // normal capture + case Boxed // capture under box + case ByName // capture applies to enclosing by-name type (only possible before ElimByName) diff --git a/compiler/src/dotty/tools/dotc/cc/CapturingType.scala b/compiler/src/dotty/tools/dotc/cc/CapturingType.scala index 738e746d0178..bca791e46205 100644 --- a/compiler/src/dotty/tools/dotc/cc/CapturingType.scala +++ b/compiler/src/dotty/tools/dotc/cc/CapturingType.scala @@ -5,26 +5,46 @@ package cc import core.* import Types.*, Symbols.*, Contexts.* +/** A capturing type. This is internally represented as an annotated type with a `retains` + * annotation, but the extractor will succeed only at phase CheckCaptures. + * Annotated types with `@retainsByName` annotation can also be created that way, by + * giving a `CapturingKind.ByName` as `kind` argument, but they are never extracted, + * since they have already been converted to regular capturing types before CheckCaptures. + */ object CapturingType: - def apply(parent: Type, refs: CaptureSet, boxed: Boolean)(using Context): Type = + def apply(parent: Type, refs: CaptureSet, kind: CapturingKind)(using Context): Type = if refs.isAlwaysEmpty then parent - else AnnotatedType(parent, CaptureAnnotation(refs, boxed)) - - def unapply(tp: AnnotatedType)(using Context): Option[(Type, CaptureSet, Boolean)] = - if ctx.phase == Phases.checkCapturesPhase then EventuallyCapturingType.unapply(tp) + else AnnotatedType(parent, CaptureAnnotation(refs, kind)) + + def unapply(tp: AnnotatedType)(using Context): Option[(Type, CaptureSet, CapturingKind)] = + if ctx.phase == Phases.checkCapturesPhase then + val r = EventuallyCapturingType.unapply(tp) + r match + case Some((_, _, CapturingKind.ByName)) => None + case _ => r else None end CapturingType +/** An extractor for types that will be capturing types at phase CheckCaptures. Also + * included are types that indicate captures on enclosing call-by-name parameters + * before phase ElimByName + */ object EventuallyCapturingType: - def unapply(tp: AnnotatedType)(using Context): Option[(Type, CaptureSet, Boolean)] = - if tp.annot.symbol == defn.RetainsAnnot then + def unapply(tp: AnnotatedType)(using Context): Option[(Type, CaptureSet, CapturingKind)] = + val sym = tp.annot.symbol + if sym == defn.RetainsAnnot || sym == defn.RetainsByNameAnnot then tp.annot match - case ann: CaptureAnnotation => Some((tp.parent, ann.refs, ann.boxed)) + case ann: CaptureAnnotation => + Some((tp.parent, ann.refs, ann.kind)) case ann => - try Some((tp.parent, ann.tree.toCaptureSet, ann.tree.isBoxedCapturing)) + val kind = + if ann.tree.isBoxedCapturing then CapturingKind.Boxed + else if sym == defn.RetainsByNameAnnot then CapturingKind.ByName + else CapturingKind.Regular + try Some((tp.parent, ann.tree.toCaptureSet, kind)) catch case ex: IllegalCaptureRef => None else None diff --git a/compiler/src/dotty/tools/dotc/cc/Setup.scala b/compiler/src/dotty/tools/dotc/cc/Setup.scala index b5ccb002e209..622baff6f2a3 100644 --- a/compiler/src/dotty/tools/dotc/cc/Setup.scala +++ b/compiler/src/dotty/tools/dotc/cc/Setup.scala @@ -25,7 +25,8 @@ extends tpd.TreeTraverser: .toFunctionType(isJava = false, alwaysDependent = true) private def box(tp: Type)(using Context): Type = tp match - case CapturingType(parent, refs, false) => CapturingType(parent, refs, true) + case CapturingType(parent, refs, CapturingKind.Regular) => + CapturingType(parent, refs, CapturingKind.Boxed) case _ => tp private def setBoxed(tp: Type)(using Context) = tp match @@ -77,7 +78,7 @@ extends tpd.TreeTraverser: cls.paramGetters.foldLeft(tp) { (core, getter) => if getter.termRef.isTracked then val getterType = tp.memberInfo(getter).strippedDealias - RefinedType(core, getter.name, CapturingType(getterType, CaptureSet.Var(), boxed = false)) + RefinedType(core, getter.name, CapturingType(getterType, CaptureSet.Var(), CapturingKind.Regular)) .showing(i"add capture refinement $tp --> $result", capt) else core @@ -130,7 +131,7 @@ extends tpd.TreeTraverser: case tp @ OrType(tp1, CapturingType(parent2, refs2, boxed2)) => CapturingType(OrType(tp1, parent2, tp.isSoft), refs2, boxed2) case _ if canHaveInferredCapture(tp) => - CapturingType(tp, CaptureSet.Var(), boxed = false) + CapturingType(tp, CaptureSet.Var(), CapturingKind.Regular) case _ => tp diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 9b65ce971b64..2e76fc48700b 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -14,7 +14,7 @@ import typer.ImportInfo.RootRef import Comments.CommentsContext import Comments.Comment import util.Spans.NoSpan -import cc.{CapturingType, CaptureSet} +import cc.{CapturingType, CaptureSet, CapturingKind, EventuallyCapturingType} import scala.annotation.tailrec @@ -117,9 +117,9 @@ class Definitions { * * ErasedFunctionN and ErasedContextFunctionN erase to Function0. * - * EffXYZFunctionN afollow this template: + * ImpureXYZFunctionN follow this template: * - * type EffXYZFunctionN[-T0,...,-T{N-1}, +R] = {*} XYZFunctionN[T0,...,T{N-1}, R] + * type ImpureXYZFunctionN[-T0,...,-T{N-1}, +R] = {*} XYZFunctionN[T0,...,T{N-1}, R] */ private def newFunctionNType(name: TypeName): Symbol = { val impure = name.startsWith("Impure") @@ -135,7 +135,7 @@ class Definitions { HKTypeLambda(argParamNames :+ "R".toTypeName, argVariances :+ Covariant)( tl => List.fill(arity + 1)(TypeBounds.empty), tl => CapturingType(underlyingClass.typeRef.appliedTo(tl.paramRefs), - CaptureSet.universal, boxed = false) + CaptureSet.universal, CapturingKind.Regular) )) else val cls = denot.asClass.classSymbol @@ -968,6 +968,7 @@ class Definitions { @tu lazy val VarargsAnnot: ClassSymbol = requiredClass("scala.annotation.varargs") @tu lazy val SinceAnnot: ClassSymbol = requiredClass("scala.annotation.since") @tu lazy val RetainsAnnot: ClassSymbol = requiredClass("scala.retains") + @tu lazy val RetainsByNameAnnot: ClassSymbol = requiredClass("scala.retainsByName") @tu lazy val JavaRepeatableAnnot: ClassSymbol = requiredClass("java.lang.annotation.Repeatable") @@ -1101,9 +1102,16 @@ class Definitions { } } + /** Extractor for function types representing by-name parameters, of the form + * `() ?=> T`. + * Under -Ycc, this becomes `() ?-> T` or `{r1, ..., rN} () ?-> T`. + */ object ByNameFunction: - def apply(tp: Type)(using Context): Type = - defn.ContextFunction0.typeRef.appliedTo(tp :: Nil) + def apply(tp: Type)(using Context): Type = tp match + case EventuallyCapturingType(tp1, refs, CapturingKind.ByName) => + CapturingType(apply(tp1), refs, CapturingKind.Regular) + case _ => + defn.ContextFunction0.typeRef.appliedTo(tp :: Nil) def unapply(tp: Type)(using Context): Option[Type] = tp match case tp @ AppliedType(tycon, arg :: Nil) if defn.isByNameFunctionClass(tycon.typeSymbol) => Some(arg) diff --git a/compiler/src/dotty/tools/dotc/core/StdNames.scala b/compiler/src/dotty/tools/dotc/core/StdNames.scala index bcb9f78c7ad6..4dde590d3e6a 100644 --- a/compiler/src/dotty/tools/dotc/core/StdNames.scala +++ b/compiler/src/dotty/tools/dotc/core/StdNames.scala @@ -580,6 +580,7 @@ object StdNames { val reify : N = "reify" val releaseFence : N = "releaseFence" val retains: N = "retains" + val retainsByName: N = "retainsByName" val rootMirror : N = "rootMirror" val run: N = "run" val runOrElse: N = "runOrElse" diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 112ae1e76e3c..4bde74a2ade8 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -24,7 +24,7 @@ import typer.Applications.productSelectorTypes import reporting.trace import NullOpsDecorator._ import annotation.constructorOnly -import cc.{CapturingType, derivedCapturingType, CaptureSet, stripCapturing} +import cc.{CapturingType, derivedCapturingType, CaptureSet, CapturingKind, stripCapturing} /** Provides methods to compare types. */ @@ -832,7 +832,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling tp1 match case tp1: CaptureRef if tp1.isTracked => val stripped = tp1w.stripCapturing - tp1w = CapturingType(stripped, tp1.singletonCaptureSet, boxed = false) + tp1w = CapturingType(stripped, tp1.singletonCaptureSet, CapturingKind.Regular) case _ => isSubType(tp1w, tp2, approx.addLow) } diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index d0b919bf835b..7cd5841e5218 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -38,7 +38,7 @@ import scala.util.hashing.{ MurmurHash3 => hashing } import config.Printers.{core, typr, matchTypes} import reporting.{trace, Message} import java.lang.ref.WeakReference -import cc.{CapturingType, CaptureSet, derivedCapturingType, retainedElems, isBoxedCapturing} +import cc.{CapturingType, CaptureSet, derivedCapturingType, retainedElems, isBoxedCapturing, CapturingKind} import CaptureSet.CompareResult import scala.annotation.internal.sharable @@ -1869,13 +1869,15 @@ object Types { def capturing(ref: CaptureRef)(using Context): Type = if captureSet.accountsFor(ref) then this - else CapturingType(this, ref.singletonCaptureSet, this.isBoxedCapturing) + else CapturingType(this, ref.singletonCaptureSet, + if this.isBoxedCapturing then CapturingKind.Boxed else CapturingKind.Regular) def capturing(cs: CaptureSet)(using Context): Type = if cs.isConst && cs.subCaptures(captureSet, frozen = true).isOK then this else this match case CapturingType(parent, cs1, boxed) => parent.capturing(cs1 ++ cs) - case _ => CapturingType(this, cs, this.isBoxedCapturing) + case _ => CapturingType(this, cs, + if this.isBoxedCapturing then CapturingKind.Boxed else CapturingKind.Regular) /** The set of distinct symbols referred to by this type, after all aliases are expanded */ def coveringSet(using Context): Set[Symbol] = @@ -3796,10 +3798,11 @@ object Types { CapturingType(parent1, CaptureSet.universal, boxed)) case AnnotatedType(parent, ann) if ann.refersToParamOf(thisLambdaType) => val parent1 = mapOver(parent) - if ann.symbol == defn.RetainsAnnot then + if ann.symbol == defn.RetainsAnnot || ann.symbol == defn.RetainsByNameAnnot then + val byName = ann.symbol == defn.RetainsByNameAnnot range( - AnnotatedType(parent1, CaptureSet.empty.toRegularAnnotation), - AnnotatedType(parent1, CaptureSet.universal.toRegularAnnotation)) + AnnotatedType(parent1, CaptureSet.empty.toRegularAnnotation(byName)), + AnnotatedType(parent1, CaptureSet.universal.toRegularAnnotation(byName))) else parent1 case _ => mapOver(tp) diff --git a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala index 952a6022428e..55d0a3c8e2a8 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala @@ -1402,8 +1402,9 @@ object Parsers { val resultType = typ() if token == TLARROW then - for case ValDef(_, tpt: ByNameTypeTree, _) <- params do - syntaxError(em"parameter of type lambda may not be call-by-name", tpt.span) + for case ValDef(_, tpt, _) <- params do + if isByNameType(tpt) then + syntaxError(em"parameter of type lambda may not be call-by-name", tpt.span) TermLambdaTypeTree(params.asInstanceOf[List[ValDef]], resultType) else if imods.isOneOf(Given | Erased | Impure) then if imods.is(Given) && params.isEmpty then @@ -1448,15 +1449,13 @@ object Parsers { if isValParamList || in.isArrow then functionRest(ts) else { - val ts1 = - for (t <- ts) yield - t match { - case t@ByNameTypeTree(t1) => - syntaxError(ByNameParameterNotSupported(t), t.span) - t1 - case _ => - t - } + val ts1 = ts.mapConserve { t => + if isByNameType(t) then + syntaxError(ByNameParameterNotSupported(t), t.span) + stripByNameType(t) + else + t + } val tuple = atSpan(start) { makeTupleOrParens(ts1) } infixTypeRest( refinedTypeRest( @@ -1793,17 +1792,48 @@ object Parsers { else commaSeparated(() => argType()) } - /** FunArgType ::= Type | `=>' Type + def paramTypeOf(core: () => Tree): Tree = + if in.token == ARROW || isIdent(nme.PUREARROW) then + val isImpure = in.token == ARROW + val tp = atSpan(in.skipToken()) { ByNameTypeTree(core()) } + if isImpure && ctx.settings.Ycc.value then ImpureByNameTypeTree(tp) else tp + else if in.token == LBRACE && followingIsCaptureSet() then + val start = in.offset + val cs = captureSet() + val endCsOffset = in.lastOffset + val startTpOffset = in.offset + val tp = paramTypeOf(core) + val tp1 = tp match + case ImpureByNameTypeTree(tp1) => + syntaxError("explicit captureSet is superfluous for impure call-by-name type", start) + tp1 + case CapturingTypeTree(_, tp1: ByNameTypeTree) => + syntaxError("only one captureSet is allowed here", start) + tp1 + case _: ByNameTypeTree if startTpOffset > endCsOffset => + report.warning( + i"""Style: by-name `->` should immediately follow closing `}` of capture set + |to avoid confusion with function type. + |That is, `{c}-> T` instead of `{c} -> T`.""", + source.atSpan(Span(startTpOffset, startTpOffset))) + tp + case _ => + tp + CapturingTypeTree(cs, tp1) + else + core() + + /** FunArgType ::= Type + * | `=>' Type + * | [CaptureSet] `->' Type */ - val funArgType: () => Tree = () => - if (in.token == ARROW) atSpan(in.skipToken()) { ByNameTypeTree(typ()) } - else typ() + val funArgType: () => Tree = () => paramTypeOf(typ) - /** ParamType ::= [`=>'] ParamValueType + /** ParamType ::= ParamValueType + * | `=>' ParamValueType + * | [CaptureSet] `->' ParamValueType */ - def paramType(): Tree = - if (in.token == ARROW) atSpan(in.skipToken()) { ByNameTypeTree(paramValueType()) } - else paramValueType() + def paramType(): Tree = paramTypeOf(paramValueType) /** ParamValueType ::= Type [`*'] */ @@ -3071,6 +3101,7 @@ object Parsers { acceptColon() if (in.token == ARROW && ofClass && !mods.is(Local)) syntaxError(VarValParametersMayNotBeCallByName(name, mods.is(Mutable))) + // needed?, it's checked later anyway val tpt = paramType() val default = if (in.token == EQUALS) { in.nextToken(); subExpr() } diff --git a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala index 6409d37ef735..bca08f54ade7 100644 --- a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala @@ -15,7 +15,7 @@ import util.SourcePosition import scala.util.control.NonFatal import scala.annotation.switch import config.Config -import cc.{EventuallyCapturingType, CaptureSet} +import cc.{CapturingType, EventuallyCapturingType, CaptureSet, CapturingKind} class PlainPrinter(_ctx: Context) extends Printer { @@ -200,8 +200,8 @@ class PlainPrinter(_ctx: Context) extends Printer { keywordStr(" match ") ~ "{" ~ casesText ~ "}" ~ (" <: " ~ toText(bound) provided !bound.isAny) }.close - case EventuallyCapturingType(parent, refs, boxed) => - def box = Str("box ") provided boxed + case EventuallyCapturingType(parent, refs, kind) => + def box = Str("box ") provided kind == CapturingKind.Boxed if printDebug && !refs.isConst then changePrec(GlobalPrec)(box ~ s"$refs " ~ toText(parent)) else if ctx.settings.YccDebug.value then @@ -232,8 +232,13 @@ class PlainPrinter(_ctx: Context) extends Printer { ~ (if tp.resultType.isInstanceOf[MethodType] then ")" else "): ") ~ toText(tp.resultType) } - case tp: ExprType => - changePrec(GlobalPrec) { "=> " ~ toText(tp.resultType) } + case ExprType(ct @ EventuallyCapturingType(parent, refs, CapturingKind.ByName)) => + if refs.isUniversal then changePrec(GlobalPrec) { "=> " ~ toText(parent) } + else toText(CapturingType(ExprType(parent), refs, CapturingKind.Regular)) + case ExprType(restp) => + changePrec(GlobalPrec) { + (if ctx.settings.Ycc.value then "-> " else "=> ") ~ toText(restp) + } case tp: HKTypeLambda => changePrec(GlobalPrec) { "[" ~ paramsText(tp) ~ "]" ~ lambdaHash(tp) ~ Str(" =>> ") ~ toTextGlobal(tp.resultType) diff --git a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala index 06cf18b9dcef..e361cf122c5f 100644 --- a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala @@ -554,7 +554,8 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { (" <: " ~ toText(bound) provided !bound.isEmpty) } case ByNameTypeTree(tpt) => - "=> " ~ toTextLocal(tpt) + (if ctx.settings.Ycc.value then "-> " else "=> ") + ~ toTextLocal(tpt) case TypeBoundsTree(lo, hi, alias) => if (lo eq hi) && alias.isEmpty then optText(lo)(" = " ~ _) else optText(lo)(" >: " ~ _) ~ optText(hi)(" <: " ~ _) ~ optText(alias)(" = " ~ _) @@ -719,7 +720,11 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { val argsText = toTextGlobal(args, ", ") prefix ~~ idx.toString ~~ "|" ~~ argsText ~~ postfix case CapturingTypeTree(refs, parent) => - changePrec(GlobalPrec)("{" ~ Text(refs.map(toText), ", ") ~ "} " ~ toText(parent)) + parent match + case ImpureByNameTypeTree(bntpt) => + "=> " ~ toTextLocal(bntpt) + case _ => + changePrec(GlobalPrec)("{" ~ Text(refs.map(toText), ", ") ~ "} " ~ toText(parent)) case _ => tree.fallbackToText(this) } diff --git a/compiler/src/dotty/tools/dotc/reporting/messages.scala b/compiler/src/dotty/tools/dotc/reporting/messages.scala index 96676f04ce99..2087af10fd07 100644 --- a/compiler/src/dotty/tools/dotc/reporting/messages.scala +++ b/compiler/src/dotty/tools/dotc/reporting/messages.scala @@ -666,7 +666,7 @@ import transform.SymUtils._ } } - class ByNameParameterNotSupported(tpe: untpd.TypTree)(using Context) + class ByNameParameterNotSupported(tpe: untpd.Tree)(using Context) extends SyntaxMsg(ByNameParameterNotSupportedID) { def msg = em"By-name parameter type ${tpe} not allowed here." diff --git a/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala b/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala index f64b970b5c80..55e3cefefdce 100644 --- a/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala +++ b/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala @@ -125,13 +125,15 @@ object ErrorReporting { def typeMismatch(tree: Tree, pt: Type, implicitFailure: SearchFailureType = NoMatchingImplicits): Tree = { val normTp = normalize(tree.tpe, pt) - val treeTp = if (normTp <:< pt) tree.tpe else normTp - // use normalized type if that also shows an error, original type otherwise + val normPt = normalize(pt, pt) + val (treeTp, expectedTp) = + if (normTp <:< normPt) (tree.tpe, pt) else (normTp, normPt) + // use normalized types if that also shows an error, original types otherwise def missingElse = tree match case If(_, _, elsep @ Literal(Constant(()))) if elsep.span.isSynthetic => "\nMaybe you are missing an else part for the conditional?" case _ => "" - errorTree(tree, TypeMismatch(treeTp, pt, Some(tree), implicitFailure.whyNoConversion, missingElse)) + errorTree(tree, TypeMismatch(treeTp, expectedTp, Some(tree), implicitFailure.whyNoConversion, missingElse)) } /** A subtype log explaining why `found` does not conform to `expected` */ diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 950afa76a2e3..1ff616f07bf3 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -2598,7 +2598,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer registerNowarn(annot1, tree) val arg1 = typed(tree.arg, pt) if (ctx.mode is Mode.Type) { - if annot1.symbol.maybeOwner == defn.RetainsAnnot then + val cls = annot1.symbol.maybeOwner + if cls == defn.RetainsAnnot || cls == defn.RetainsByNameAnnot then CheckCaptures.checkWellformed(annot1) if arg1.isType then assignType(cpy.Annotated(tree)(arg1, annot1), arg1, annot1) diff --git a/library/src-bootstrapped/scala/retainsByName.scala b/library/src-bootstrapped/scala/retainsByName.scala new file mode 100644 index 000000000000..c530f35ec0e4 --- /dev/null +++ b/library/src-bootstrapped/scala/retainsByName.scala @@ -0,0 +1,6 @@ +package scala + +/** An annotation that indicates capture of an enclosing by-name type + */ +class retainsByName(xs: Any*) extends annotation.StaticAnnotation + diff --git a/tests/neg-custom-args/captures/byname.check b/tests/neg-custom-args/captures/byname.check new file mode 100644 index 000000000000..d8d5d689efae --- /dev/null +++ b/tests/neg-custom-args/captures/byname.check @@ -0,0 +1,20 @@ +-- Warning: tests/neg-custom-args/captures/byname.scala:14:18 ---------------------------------------------------------- +14 | def h(x: {cap1} -> I) = x // warning + | ^ + | Style: by-name `->` should immediately follow closing `}` of capture set + | to avoid confusion with function type. + | That is, `{c}-> T` instead of `{c} -> T`. +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/byname.scala:7:5 ----------------------------------------- +7 | h(f()) // error + | ^^^ + | Found: {cap2} (x$0: Int) -> Int + | Required: Int -> Int + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/byname.scala:16:5 ---------------------------------------- +16 | h(g()) // error + | ^^^ + | Found: {cap2} () ?-> I + | Required: {cap1} () ?-> I + +longer explanation available when compiling with `-explain` diff --git a/tests/neg-custom-args/captures/byname.scala b/tests/neg-custom-args/captures/byname.scala index feb9461dc4c7..0afca7e261e2 100644 --- a/tests/neg-custom-args/captures/byname.scala +++ b/tests/neg-custom-args/captures/byname.scala @@ -6,4 +6,14 @@ def test(cap1: Cap, cap2: Cap) = def h(ff: => {cap2} Int -> Int) = ff h(f()) // error +class I + +def test2(cap1: Cap, cap2: Cap): {cap1} I = + def f() = if cap1 == cap1 then I() else I() + def g() = if cap2 == cap2 then I() else I() + def h(x: {cap1} -> I) = x // warning + h(f()) // OK + h(g()) // error + + diff --git a/tests/neg-custom-args/captures/lazylist.check b/tests/neg-custom-args/captures/lazylist.check index 0de190df8f11..31624e437928 100644 --- a/tests/neg-custom-args/captures/lazylist.check +++ b/tests/neg-custom-args/captures/lazylist.check @@ -1,8 +1,8 @@ -- [E163] Declaration Error: tests/neg-custom-args/captures/lazylist.scala:22:6 ---------------------------------------- 22 | def tail: {*} LazyList[Nothing] = ??? // error overriding | ^ - | error overriding method tail in class LazyList of type => lazylists.LazyList[Nothing]; - | method tail of type => {*} lazylists.LazyList[Nothing] has incompatible type + | error overriding method tail in class LazyList of type -> lazylists.LazyList[Nothing]; + | method tail of type -> {*} lazylists.LazyList[Nothing] has incompatible type longer explanation available when compiling with `-explain` -- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylist.scala:35:29 ------------------------------------- diff --git a/tests/neg-custom-args/captures/lazylists2.check b/tests/neg-custom-args/captures/lazylists2.check index 8e09dd26cccf..a8f145ecd9d7 100644 --- a/tests/neg-custom-args/captures/lazylists2.check +++ b/tests/neg-custom-args/captures/lazylists2.check @@ -1,8 +1,8 @@ -- [E163] Declaration Error: tests/neg-custom-args/captures/lazylists2.scala:50:10 ------------------------------------- 50 | def tail: {xs, f} LazyList[B] = xs.tail.map(f) // error | ^ - | error overriding method tail in trait LazyList of type => {Mapped.this} LazyList[B]; - | method tail of type => {xs, f} LazyList[B] has incompatible type + | error overriding method tail in trait LazyList of type -> {Mapped.this} LazyList[B]; + | method tail of type -> {xs, f} LazyList[B] has incompatible type longer explanation available when compiling with `-explain` -- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylists2.scala:18:4 ------------------------------------ diff --git a/tests/pos-custom-args/captures/byname.scala b/tests/pos-custom-args/captures/byname.scala index a3d80f31d579..5cb5255d4652 100644 --- a/tests/pos-custom-args/captures/byname.scala +++ b/tests/pos-custom-args/captures/byname.scala @@ -5,6 +5,6 @@ class I def test(cap1: Cap, cap2: Cap): {cap1} I = def f() = if cap1 == cap1 then I() else I() - def h(x: /*=>*/ {cap1} I) = x // TODO: enable cbn - h(f()) + def h(x: {cap1}-> I) = x + h(f()) // OK From 91299b24076c526ba112f0e8bc18fb9ec7082893 Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Tue, 25 Jan 2022 19:40:04 +0100 Subject: [PATCH 18/24] Treat exceptions as capabilities 1. Make CanThrow a @capability class 2. Fix pure arrow handling in parser 3. Avoid misleading type mismatch message 4. Make map and filter conserve Const capturesets if there's no change 5. Expand $throws clauses to context function types 6. Exempt compiletime.erasedValue for "no '*'" checks 7. Capability escape checking for try --- .../src/dotty/tools/dotc/cc/CaptureSet.scala | 13 +++- .../dotty/tools/dotc/cc/CapturingType.scala | 2 +- compiler/src/dotty/tools/dotc/cc/Setup.scala | 36 +++++++++-- .../dotty/tools/dotc/parsing/Parsers.scala | 7 ++- .../dotty/tools/dotc/parsing/Scanners.scala | 3 + .../dotty/tools/dotc/transform/Recheck.scala | 5 +- .../tools/dotc/typer/CheckCaptures.scala | 45 ++++++++----- .../tools/dotc/typer/ErrorReporting.scala | 16 ++++- .../src/dotty/tools/dotc/typer/Typer.scala | 2 +- .../tools/dotc/util/SimpleIdentitySet.scala | 4 ++ library/src/scala/CanThrow.scala | 4 +- tests/neg-custom-args/captures/real-try.check | 8 +++ tests/neg-custom-args/captures/real-try.scala | 14 +++++ tests/pos-custom-args/captures/i13816.scala | 63 +++++++++++++++++++ 14 files changed, 188 insertions(+), 34 deletions(-) create mode 100644 tests/neg-custom-args/captures/real-try.check create mode 100644 tests/neg-custom-args/captures/real-try.scala create mode 100644 tests/pos-custom-args/captures/i13816.scala diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala b/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala index 6118e54174cd..b8daef92beef 100644 --- a/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala +++ b/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala @@ -173,7 +173,10 @@ sealed abstract class CaptureSet extends Showable: this -- ref.singletonCaptureSet def filter(p: CaptureRef => Boolean)(using Context): CaptureSet = - if this.isConst then Const(elems.filter(p)) + if this.isConst then + val elems1 = elems.filter(p) + if elems1 == elems then this + else Const(elems.filter(p)) else Filtered(asVar, p) /** capture set obtained by applying `f` to all elements of the current capture set @@ -183,11 +186,15 @@ sealed abstract class CaptureSet extends Showable: def map(tm: TypeMap)(using Context): CaptureSet = tm match case tm: BiTypeMap => val mappedElems = elems.map(tm.forward) - if isConst then Const(mappedElems) + if isConst then + if mappedElems == elems then this + else Const(mappedElems) else BiMapped(asVar, tm, mappedElems) case _ => val mapped = mapRefs(elems, tm, tm.variance) - if isConst then mapped + if isConst then + if mapped.isConst && mapped.elems == elems then this + else mapped else Mapped(asVar, tm, tm.variance, mapped) def substParams(tl: BindingType, to: List[Type])(using Context) = diff --git a/compiler/src/dotty/tools/dotc/cc/CapturingType.scala b/compiler/src/dotty/tools/dotc/cc/CapturingType.scala index bca791e46205..d19850b72e4f 100644 --- a/compiler/src/dotty/tools/dotc/cc/CapturingType.scala +++ b/compiler/src/dotty/tools/dotc/cc/CapturingType.scala @@ -29,7 +29,7 @@ end CapturingType /** An extractor for types that will be capturing types at phase CheckCaptures. Also * included are types that indicate captures on enclosing call-by-name parameters - * before phase ElimByName + * before phase ElimByName. */ object EventuallyCapturingType: diff --git a/compiler/src/dotty/tools/dotc/cc/Setup.scala b/compiler/src/dotty/tools/dotc/cc/Setup.scala index 622baff6f2a3..4197d097d8aa 100644 --- a/compiler/src/dotty/tools/dotc/cc/Setup.scala +++ b/compiler/src/dotty/tools/dotc/cc/Setup.scala @@ -44,6 +44,29 @@ extends tpd.TreeTraverser: case _ => traverseChildren(t) + /** Expand some aliases of function types to the underlying functions. + * Right now, these are only $throws aliases, but this could be generalized. + */ + def expandInlineAlias(tp: Type)(using Context) = tp match + case AppliedType(tycon, res :: exc :: Nil) if tycon.typeSymbol == defn.throwsAlias => + // hard-coded expansion since $throws aliases in stdlib are defined with `?=>` rather than `?->` + defn.FunctionOf(defn.CanThrowClass.typeRef.appliedTo(exc) :: Nil, res, isContextual = true, isErased = true) + case _ => tp + + private def expandInlineAliases(using Context) = new TypeMap: + def apply(t: Type) = t match + case _: AppliedType => + val t1 = expandInlineAlias(t) + if t1 ne t then apply(t1) else mapOver(t) + case _: LazyRef => + t + case t @ AnnotatedType(t1, ann) => + // Don't map capture sets, since that would implicitly normalize sets that + // are not well-formed. + t.derivedAnnotatedType(apply(t1), ann) + case _ => + mapOver(t) + /** Perform the following transformation steps everywhere in a type: * 1. Drop retains annotations * 2. Turn plain function types into dependent function types, so that @@ -143,7 +166,8 @@ extends tpd.TreeTraverser: try ts.mapConserve(this) finally isTopLevel = saved def apply(t: Type) = - val t1 = t match + val tp = expandInlineAlias(t) + val tp1 = tp match case AnnotatedType(parent, annot) if annot.symbol == defn.RetainsAnnot => apply(parent) case tp @ AppliedType(tycon, args) => @@ -172,8 +196,8 @@ extends tpd.TreeTraverser: paramInfos = tp.paramInfos.mapConserve(cleanup(_).bounds), resType = this(tp.resType)) case _ => - mapOver(t) - addVar(addCaptureRefinements(t1)) + mapOver(tp) + addVar(addCaptureRefinements(tp1)) end mapInferred private def expandAbbreviations(using Context) = new TypeMap: @@ -232,8 +256,10 @@ extends tpd.TreeTraverser: private def transformExplicitType(tp: Type, boxed: Boolean)(using Context): Type = addBoxes.traverse(tp) if boxed then setBoxed(tp) - if ctx.settings.YccNoAbbrev.value then tp - else expandAbbreviations(tp) + val tp1 = expandInlineAliases(tp) + if tp1 ne tp then capt.println(i"expanded: $tp --> $tp1") + if ctx.settings.YccNoAbbrev.value then tp1 + else expandAbbreviations(tp1) // Substitute parameter symbols in `from` to paramRefs in corresponding // method or poly types `to`. We use a single BiTypeMap to do everything. diff --git a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala index 55d0a3c8e2a8..cf03ea4e1a5b 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala @@ -421,7 +421,10 @@ object Parsers { /** Convert tree to formal parameter list */ def convertToParams(tree: Tree): List[ValDef] = - val mods = if in.token == CTXARROW then Modifiers(Given) else EmptyModifiers + val mods = + if in.token == CTXARROW || in.isIdent(nme.PURECTXARROW) + then Modifiers(Given) + else EmptyModifiers tree match case Parens(t) => convertToParam(t, mods) :: Nil @@ -1446,7 +1449,7 @@ object Parsers { funTypeArgsRest(t, funArgType) } accept(RPAREN) - if isValParamList || in.isArrow then + if isValParamList || in.isArrow || in.isPureArrow then functionRest(ts) else { val ts1 = ts.mapConserve { t => diff --git a/compiler/src/dotty/tools/dotc/parsing/Scanners.scala b/compiler/src/dotty/tools/dotc/parsing/Scanners.scala index 1a2f3cd3d86a..b78d864de7df 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Scanners.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Scanners.scala @@ -93,6 +93,9 @@ object Scanners { def isArrow = token == ARROW || token == CTXARROW + + def isPureArrow = + isIdent(nme.PUREARROW) || isIdent(nme.PURECTXARROW) } abstract class ScannerCommon(source: SourceFile)(using Context) extends CharArrayReader with TokenData { diff --git a/compiler/src/dotty/tools/dotc/transform/Recheck.scala b/compiler/src/dotty/tools/dotc/transform/Recheck.scala index bdb49c5a8edb..33de86d091ef 100644 --- a/compiler/src/dotty/tools/dotc/transform/Recheck.scala +++ b/compiler/src/dotty/tools/dotc/transform/Recheck.scala @@ -337,7 +337,9 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: def recheckFinish(tpe: Type, tree: Tree, pt: Type)(using Context): Type = checkConforms(tpe, pt, tree) - if keepTypes then tree.rememberType(tpe) + if keepTypes + || tree.isInstanceOf[Try] // type needs tp be checked for * escapes + then tree.rememberType(tpe) tpe def recheck(tree: Tree, pt: Type = WildcardType)(using Context): Type = @@ -363,6 +365,7 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: || expected.isRepeatedParam && actual <:< expected.translateFromRepeated(toArray = tree.tpe.isRef(defn.ArrayClass)) if !isCompatible then + recheckr.println(i"conforms failed for ${tree}: $tpe vs $expected") err.typeMismatch(tree.withType(tpe), expected) else if debugSuccesses then tree match diff --git a/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala index bfe93514548e..c4e62e66ac75 100644 --- a/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala +++ b/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala @@ -16,6 +16,7 @@ import transform.Recheck import Recheck.* import scala.collection.mutable import CaptureSet.withCaptureSetsExplained +import reporting.trace object CheckCaptures: import ast.tpd.* @@ -75,7 +76,15 @@ object CheckCaptures: if remaining.accountsFor(firstRef) then report.warning(em"redundant capture: $remaining already accounts for $firstRef", ann.srcPos) - private inline val disallowGlobal = true + /** Does this function allow type arguments carrying the universal capability? + * Currently this is true only for `erasedValue` since this function is magic in + * that is allows to conjure global capabilies from nothing (aside: can we find a + * more controlled way to achieve this?). + * But it could be generalized to other functions that so that they can take capability + * classes as arguments. + */ + private def allowUniversalArguments(fn: Tree)(using Context): Boolean = + fn.symbol == defn.Compiletime_erasedValue class CheckCaptures extends Recheck: thisPhase => @@ -305,13 +314,13 @@ class CheckCaptures extends Recheck: .traverse(ctx.compilationUnit.tpdTree) withCaptureSetsExplained { super.checkUnit(unit) - PostRefinerCheck.traverse(unit.tpdTree) + PostCheck.traverse(unit.tpdTree) if ctx.settings.YccDebug.value then show(unit.tpdTree) // this dows not print tree, but makes its variables visible for dependency printing } - def checkNotGlobal(tree: Tree, tp: Type, allArgs: Tree*)(using Context): Unit = - for ref <-tp.captureSet.elems do + def checkNotGlobal(tree: Tree, tp: Type, isVar: Boolean, allArgs: Tree*)(using Context): Unit = + for ref <- tp.captureSet.elems do val isGlobal = ref match case ref: TermRef => ref.isRootCapability case _ => false @@ -320,7 +329,7 @@ class CheckCaptures extends Recheck: val notAllowed = i" is not allowed to capture the $what capability $ref" def msg = if allArgs.isEmpty then - i"type of mutable variable ${tree.knownType}$notAllowed" + i"${if isVar then "type of mutable variable" else "result type"} ${tree.knownType}$notAllowed" else tree match case tree: InferredTypeTree => i"""inferred type argument ${tree.knownType}$notAllowed @@ -330,12 +339,11 @@ class CheckCaptures extends Recheck: report.error(msg, tree.srcPos) def checkNotGlobal(tree: Tree, allArgs: Tree*)(using Context): Unit = - if disallowGlobal then - tree match - case LambdaTypeTree(_, restpt) => - checkNotGlobal(restpt, allArgs*) - case _ => - checkNotGlobal(tree, tree.knownType, allArgs*) + tree match + case LambdaTypeTree(_, restpt) => + checkNotGlobal(restpt, allArgs*) + case _ => + checkNotGlobal(tree, tree.knownType, isVar = false, allArgs*) def checkNotGlobalDeep(tree: Tree)(using Context): Unit = val checker = new TypeTraverser: @@ -346,12 +354,12 @@ class CheckCaptures extends Recheck: case _ => case tp: TermRef => case _ => - checkNotGlobal(tree, tp) + checkNotGlobal(tree, tp, isVar = true) traverseChildren(tp) checker.traverse(tree.knownType) - object PostRefinerCheck extends TreeTraverser: - def traverse(tree: Tree)(using Context) = + object PostCheck extends TreeTraverser: + def traverse(tree: Tree)(using Context) = trace{i"post check $tree"} { tree match case _: InferredTypeTree => case tree: TypeTree if !tree.span.isZeroExtent => @@ -362,7 +370,7 @@ class CheckCaptures extends Recheck: checkWellformedPost(annot.tree) case _ => } - case tree1 @ TypeApply(fn, args) if disallowGlobal => + case tree1 @ TypeApply(fn, args) if !allowUniversalArguments(fn) => for arg <- args do //println(i"checking $arg in $tree: ${tree.knownType.captureSet}") checkNotGlobal(arg, args*) @@ -390,11 +398,14 @@ class CheckCaptures extends Recheck: inferred.foreachPart(checkPure, StopAt.Static) case t: ValDef if t.symbol.is(Mutable) => checkNotGlobalDeep(t.tpt) + case t: Try => + checkNotGlobal(t) case _ => traverseChildren(tree) + } - def postRefinerCheck(tree: tpd.Tree)(using Context): Unit = - PostRefinerCheck.traverse(tree) + def postCheck(tree: tpd.Tree)(using Context): Unit = + PostCheck.traverse(tree) end CaptureChecker end CheckCaptures diff --git a/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala b/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala index 55e3cefefdce..ac85a0f9a432 100644 --- a/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala +++ b/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala @@ -126,13 +126,25 @@ object ErrorReporting { def typeMismatch(tree: Tree, pt: Type, implicitFailure: SearchFailureType = NoMatchingImplicits): Tree = { val normTp = normalize(tree.tpe, pt) val normPt = normalize(pt, pt) + + def contextFunctionCount(tp: Type): Int = tp.stripped match + case defn.ContextFunctionType(_, restp, _) => 1 + contextFunctionCount(restp) + case _ => 0 + def strippedTpCount = contextFunctionCount(tree.tpe) - contextFunctionCount(normTp) + def strippedPtCount = contextFunctionCount(pt) - contextFunctionCount(normPt) + val (treeTp, expectedTp) = - if (normTp <:< normPt) (tree.tpe, pt) else (normTp, normPt) - // use normalized types if that also shows an error, original types otherwise + if normTp <:< normPt || strippedTpCount != strippedPtCount + then (tree.tpe, pt) + else (normTp, normPt) + // use normalized types if that also shows an error, and both sides stripped + // the same number of context functions. Use original types otherwise. + def missingElse = tree match case If(_, _, elsep @ Literal(Constant(()))) if elsep.span.isSynthetic => "\nMaybe you are missing an else part for the conditional?" case _ => "" + errorTree(tree, TypeMismatch(treeTp, expectedTp, Some(tree), implicitFailure.whyNoConversion, missingElse)) } diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 1ff616f07bf3..dfe5cbf40f13 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -2306,7 +2306,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer //todo: make sure dependent method types do not depend on implicits or by-name params } - /** (1) Check that the signature of the class mamber does not return a repeated parameter type + /** (1) Check that the signature of the class member does not return a repeated parameter type * (2) If info is an erased class, set erased flag of member */ private def postProcessInfo(sym: Symbol)(using Context): Unit = diff --git a/compiler/src/dotty/tools/dotc/util/SimpleIdentitySet.scala b/compiler/src/dotty/tools/dotc/util/SimpleIdentitySet.scala index 1fac0dac0913..1508a069cdab 100644 --- a/compiler/src/dotty/tools/dotc/util/SimpleIdentitySet.scala +++ b/compiler/src/dotty/tools/dotc/util/SimpleIdentitySet.scala @@ -36,6 +36,10 @@ abstract class SimpleIdentitySet[+Elem <: AnyRef] { ((SimpleIdentitySet.empty: SimpleIdentitySet[E]) /: this) { (s, x) => if (that.contains(x)) s else s + x } + + def == [E >: Elem <: AnyRef](that: SimpleIdentitySet[E]): Boolean = + this.size == that.size && forall(that.contains) + override def toString: String = toList.mkString("{", ", ", "}") } diff --git a/library/src/scala/CanThrow.scala b/library/src/scala/CanThrow.scala index fcfd11fc9197..c7f23a393715 100644 --- a/library/src/scala/CanThrow.scala +++ b/library/src/scala/CanThrow.scala @@ -1,12 +1,12 @@ package scala import language.experimental.erasedDefinitions -import annotation.{implicitNotFound, experimental} +import annotation.{implicitNotFound, experimental, capability} /** A capability class that allows to throw exception `E`. When used with the * experimental.saferExceptions feature, a `throw Ex()` expression will require * a given of class `CanThrow[Ex]` to be available. */ -@experimental +@experimental @capability @implicitNotFound("The capability to throw exception ${E} is missing.\nThe capability can be provided by one of the following:\n - Adding a using clause `(using CanThrow[${E}])` to the definition of the enclosing method\n - Adding `throws ${E}` clause after the result type of the enclosing method\n - Wrapping this piece of code with a `try` block that catches ${E}") erased class CanThrow[-E <: Exception] diff --git a/tests/neg-custom-args/captures/real-try.check b/tests/neg-custom-args/captures/real-try.check new file mode 100644 index 000000000000..11a6fdfd50dd --- /dev/null +++ b/tests/neg-custom-args/captures/real-try.check @@ -0,0 +1,8 @@ +-- Error: tests/neg-custom-args/captures/real-try.scala:10:2 ----------------------------------------------------------- +10 | try // error + | ^ + | result type {*} () -> Unit is not allowed to capture the universal capability *.type +11 | () => foo(1) +12 | catch +13 | case _: Ex1 => ??? +14 | case _: Ex2 => ??? diff --git a/tests/neg-custom-args/captures/real-try.scala b/tests/neg-custom-args/captures/real-try.scala new file mode 100644 index 000000000000..9a8ccd694dc9 --- /dev/null +++ b/tests/neg-custom-args/captures/real-try.scala @@ -0,0 +1,14 @@ +import language.experimental.saferExceptions + +class Ex1 extends Exception("Ex1") +class Ex2 extends Exception("Ex2") + +def foo(i: Int): (CanThrow[Ex1], CanThrow[Ex2]) ?-> Unit = + if i > 0 then throw new Ex1 else throw new Ex2 + +def test() = + try // error + () => foo(1) + catch + case _: Ex1 => ??? + case _: Ex2 => ??? diff --git a/tests/pos-custom-args/captures/i13816.scala b/tests/pos-custom-args/captures/i13816.scala new file mode 100644 index 000000000000..b8f9db405188 --- /dev/null +++ b/tests/pos-custom-args/captures/i13816.scala @@ -0,0 +1,63 @@ +import language.experimental.saferExceptions + +class Ex1 extends Exception("Ex1") +class Ex2 extends Exception("Ex2") + +def foo0(i: Int): (CanThrow[Ex1], CanThrow[Ex2]) ?-> Unit = + if i > 0 then throw new Ex1 else throw new Ex2 + +def foo01(i: Int): CanThrow[Ex1] ?-> CanThrow[Ex2] ?-> Unit = + if i > 0 then throw new Ex1 else throw new Ex2 + +def foo1(i: Int): Unit throws Ex1 throws Ex2 = + if i > 0 then throw new Ex1 else throw new Ex1 + +def foo2(i: Int): Unit throws Ex1 | Ex2 = + if i > 0 then throw new Ex1 else throw new Ex2 + +def foo3(i: Int): Unit throws (Ex1 | Ex2) = + if i > 0 then throw new Ex1 else throw new Ex2 + +def foo4(i: Int)(using CanThrow[Ex1], CanThrow[Ex2]): Unit = + if i > 0 then throw new Ex1 else throw new Ex2 + +def foo5(i: Int)(using CanThrow[Ex1])(using CanThrow[Ex2]): Unit = + if i > 0 then throw new Ex1 else throw new Ex2 + +def foo6(i: Int)(using CanThrow[Ex1 | Ex2]): Unit = + if i > 0 then throw new Ex1 else throw new Ex2 + +def foo7(i: Int)(using CanThrow[Ex1]): Unit throws Ex2 = + if i > 0 then throw new Ex1 else throw new Ex2 + +def foo8(i: Int)(using CanThrow[Ex2]): Unit throws Ex1 = + if i > 0 then throw new Ex1 else throw new Ex2 + +def test(): Unit = + try + foo1(1) + foo2(1) + foo3(1) + foo4(1) + foo5(1) + foo6(1) + foo7(1) + foo8(1) + catch + case _: Ex1 => + case _: Ex2 => + + try + try + foo1(1) + foo2(1) + foo3(1) + foo4(1) + foo5(1) + // foo6(1) // As explained in the docs this won't work until we find a way to aggregate capabilities + foo7(1) + foo8(1) + catch + case _: Ex1 => + catch + case _: Ex2 => From 9033e0c171a01487a3ae5a202193ad056caf7d07 Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Wed, 26 Jan 2022 12:52:33 +0100 Subject: [PATCH 19/24] Mark classes compiled under -Ycc with a CaptureChecked annotation --- .../src/dotty/tools/dotc/core/Definitions.scala | 1 + .../src/dotty/tools/dotc/transform/PostTyper.scala | 14 +++++++------- .../scala/annotation/internal/CaptureChecked.scala | 8 ++++++++ 3 files changed, 16 insertions(+), 7 deletions(-) create mode 100644 library/src/scala/annotation/internal/CaptureChecked.scala diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 2e76fc48700b..974fd3f2ab0d 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -923,6 +923,7 @@ class Definitions { @tu lazy val BooleanBeanPropertyAnnot: ClassSymbol = requiredClass("scala.beans.BooleanBeanProperty") @tu lazy val BodyAnnot: ClassSymbol = requiredClass("scala.annotation.internal.Body") @tu lazy val CapabilityAnnot: ClassSymbol = requiredClass("scala.annotation.capability") + @tu lazy val CaptureCheckedAnnot: ClassSymbol = requiredClass("scala.annotation.internal.CaptureChecked") @tu lazy val ChildAnnot: ClassSymbol = requiredClass("scala.annotation.internal.Child") @tu lazy val ContextResultCountAnnot: ClassSymbol = requiredClass("scala.annotation.internal.ContextResultCount") @tu lazy val ProvisionalSuperClassAnnot: ClassSymbol = requiredClass("scala.annotation.internal.ProvisionalSuperClass") diff --git a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala index 54bb72275921..24afec3ff22d 100644 --- a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala +++ b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala @@ -370,13 +370,13 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase for parent <- impl.parents do Checking.checkTraitInheritance(parent.tpe.classSymbol, sym.asClass, parent.srcPos) // Add SourceFile annotation to top-level classes - if sym.owner.is(Package) - && ctx.compilationUnit.source.exists - && sym != defn.SourceFileAnnot - then - val reference = ctx.settings.sourceroot.value - val relativePath = util.SourceFile.relativePath(ctx.compilationUnit.source, reference) - sym.addAnnotation(Annotation.makeSourceFile(relativePath)) + if sym.owner.is(Package) then + if ctx.compilationUnit.source.exists && sym != defn.SourceFileAnnot then + val reference = ctx.settings.sourceroot.value + val relativePath = util.SourceFile.relativePath(ctx.compilationUnit.source, reference) + sym.addAnnotation(Annotation.makeSourceFile(relativePath)) + if ctx.settings.Ycc.value && sym != defn.CaptureCheckedAnnot then + sym.addAnnotation(Annotation(defn.CaptureCheckedAnnot)) else (tree.rhs, sym.info) match case (rhs: LambdaTypeTree, bounds: TypeBounds) => VarianceChecker.checkLambda(rhs, bounds) diff --git a/library/src/scala/annotation/internal/CaptureChecked.scala b/library/src/scala/annotation/internal/CaptureChecked.scala new file mode 100644 index 000000000000..3ffea31b898c --- /dev/null +++ b/library/src/scala/annotation/internal/CaptureChecked.scala @@ -0,0 +1,8 @@ +package scala.annotation +package internal + +/** A marker annotation on a toplevel class that indicates + * that the class was checked under -Ycc + */ +class CaptureChecked extends StaticAnnotation + From 6d672b15d48a88492288efa8d5788899359918c6 Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Wed, 26 Jan 2022 15:09:46 +0100 Subject: [PATCH 20/24] Map regular function types to impure function types when unpickling Map regular function types to impure function types when unpickling a class under -Ycc that was not itself compiled with -Ycc. --- .../src/dotty/tools/dotc/cc/CaptureOps.scala | 16 +++++++++++++++- .../tools/dotc/core/tasty/TreeUnpickler.scala | 16 ++++++++++++++-- .../core/unpickleScala2/Scala2Unpickler.scala | 5 ++++- .../captures/capt-separate/Lib_1.scala | 6 ++++++ .../captures/capt-separate/Test_2.scala | 15 +++++++++++++++ .../captures/saferExceptions.scala | 18 ++++++++++++++++++ 6 files changed, 72 insertions(+), 4 deletions(-) create mode 100644 tests/pos-custom-args/captures/capt-separate/Lib_1.scala create mode 100644 tests/pos-custom-args/captures/capt-separate/Test_2.scala create mode 100644 tests/pos-custom-args/captures/saferExceptions.scala diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala index 117b6e528e62..aa7efe2a04cf 100644 --- a/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala +++ b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala @@ -5,7 +5,7 @@ package cc import core.* import Types.*, Symbols.*, Contexts.*, Annotations.* import ast.{tpd, untpd} -import Decorators.* +import Decorators.*, NameOps.* import config.Printers.capt import util.Property.Key import tpd.* @@ -71,3 +71,17 @@ extension (tp: Type) atd.derivedAnnotatedType(parent.stripCapturing, annot) case _ => tp + + /** Under -Ycc, map regular function type to impure function type + */ + def adaptFunctionType(using Context): Type = tp match + case AppliedType(fn, args) + if ctx.settings.Ycc.value && defn.isFunctionClass(fn.typeSymbol) => + val fname = fn.typeSymbol.name + defn.FunctionType( + fname.functionArity, + isContextual = fname.isContextFunction, + isErased = fname.isErasedFunction, + isImpure = true).appliedTo(args) + case _ => + tp diff --git a/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala b/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala index a217b76944fd..84359cbb37d8 100644 --- a/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala +++ b/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala @@ -31,6 +31,7 @@ import ast.{TreeTypeMap, Trees, tpd, untpd} import Trees._ import Decorators._ import transform.SymUtils._ +import cc.adaptFunctionType import dotty.tools.tasty.{TastyBuffer, TastyReader} import TastyBuffer._ @@ -87,6 +88,9 @@ class TreeUnpickler(reader: TastyReader, /** The root owner tree. See `OwnerTree` class definition. Set by `enterTopLevel`. */ private var ownerTree: OwnerTree = _ + /** Was unpickled class compiled with -Ycc? */ + private var wasCaptureChecked: Boolean = false + private def registerSym(addr: Addr, sym: Symbol) = symAtAddr(addr) = sym @@ -357,7 +361,7 @@ class TreeUnpickler(reader: TastyReader, // Note that the lambda "rt => ..." is not equivalent to a wildcard closure! // Eta expansion of the latter puts readType() out of the expression. case APPLIEDtype => - readType().appliedTo(until(end)(readType())) + postProcessFunction(readType().appliedTo(until(end)(readType()))) case TYPEBOUNDS => val lo = readType() if nothingButMods(end) then @@ -470,6 +474,12 @@ class TreeUnpickler(reader: TastyReader, def readTermRef()(using Context): TermRef = readType().asInstanceOf[TermRef] + /** Under -Ycc, map all function types to impure function types, + * unless the unpickled class was also compiled with -Ycc. + */ + private def postProcessFunction(tp: Type)(using Context): Type = + if wasCaptureChecked then tp else tp.adaptFunctionType + // ------ Reading definitions ----------------------------------------------------- private def nothingButMods(end: Addr): Boolean = @@ -605,6 +615,8 @@ class TreeUnpickler(reader: TastyReader, } registerSym(start, sym) if (isClass) { + if sym.owner.is(Package) && annots.exists(_.symbol == defn.CaptureCheckedAnnot) then + wasCaptureChecked = true sym.completer.withDecls(newScope) forkAt(templateStart).indexTemplateParams()(using localContext(sym)) } @@ -1265,7 +1277,7 @@ class TreeUnpickler(reader: TastyReader, val args = until(end)(readTpt()) val tree = untpd.AppliedTypeTree(tycon, args) val ownType = ctx.typeAssigner.processAppliedType(tree, tycon.tpe.safeAppliedTo(args.tpes)) - tree.withType(ownType) + tree.withType(postProcessFunction(ownType)) case ANNOTATEDtpt => Annotated(readTpt(), readTerm()) case LAMBDAtpt => diff --git a/compiler/src/dotty/tools/dotc/core/unpickleScala2/Scala2Unpickler.scala b/compiler/src/dotty/tools/dotc/core/unpickleScala2/Scala2Unpickler.scala index b5b8c4715ebc..693e819b2d0b 100644 --- a/compiler/src/dotty/tools/dotc/core/unpickleScala2/Scala2Unpickler.scala +++ b/compiler/src/dotty/tools/dotc/core/unpickleScala2/Scala2Unpickler.scala @@ -30,6 +30,7 @@ import scala.collection.mutable import scala.collection.mutable.ListBuffer import scala.annotation.switch import reporting._ +import cc.adaptFunctionType object Scala2Unpickler { @@ -818,7 +819,9 @@ class Scala2Unpickler(bytes: Array[Byte], classRoot: ClassDenotation, moduleClas // special-case in erasure, see TypeErasure#eraseInfo. OrType(args(0), args(1), soft = false) } - else if (args.nonEmpty) tycon.safeAppliedTo(EtaExpandIfHK(sym.typeParams, args.map(translateTempPoly))) + else if args.nonEmpty then + tycon.safeAppliedTo(EtaExpandIfHK(sym.typeParams, args.map(translateTempPoly))) + .adaptFunctionType else if (sym.typeParams.nonEmpty) tycon.EtaExpand(sym.typeParams) else tycon case TYPEBOUNDStpe => diff --git a/tests/pos-custom-args/captures/capt-separate/Lib_1.scala b/tests/pos-custom-args/captures/capt-separate/Lib_1.scala new file mode 100644 index 000000000000..c620ebca3631 --- /dev/null +++ b/tests/pos-custom-args/captures/capt-separate/Lib_1.scala @@ -0,0 +1,6 @@ +object Lib: + extension [A](xs: Seq[A]) + def mapp[B](f: A => B): Seq[B] = + xs.map(f.asInstanceOf[A -> B]) + + diff --git a/tests/pos-custom-args/captures/capt-separate/Test_2.scala b/tests/pos-custom-args/captures/capt-separate/Test_2.scala new file mode 100644 index 000000000000..1a3b5bbd8571 --- /dev/null +++ b/tests/pos-custom-args/captures/capt-separate/Test_2.scala @@ -0,0 +1,15 @@ +import language.experimental.saferExceptions +import Lib.* + +class LimitExceeded extends Exception + +val limit = 10e9 + +def f(x: Double): Double throws LimitExceeded = + if x < limit then x * x else throw LimitExceeded() + +@main def test(xs: Double*) = + try println(xs.mapp(f).sum) + catch case ex: LimitExceeded => println("too large") + + diff --git a/tests/pos-custom-args/captures/saferExceptions.scala b/tests/pos-custom-args/captures/saferExceptions.scala new file mode 100644 index 000000000000..47793c7450c8 --- /dev/null +++ b/tests/pos-custom-args/captures/saferExceptions.scala @@ -0,0 +1,18 @@ +import language.experimental.saferExceptions + +class LimitExceeded extends Exception + +val limit = 10e9 + +extension [A](xs: Seq[A]) + def mapp[B](f: A => B): Seq[B] = + xs.map(f.asInstanceOf[A -> B]) + +def f(x: Double): Double throws LimitExceeded = + if x < limit then x * x else throw LimitExceeded() + +@main def test(xs: Double*) = + try println(xs.mapp(f).sum + xs.map(f).sum) + catch case ex: LimitExceeded => println("too large") + + From 18dd570640115a75daef6d765cb4d339b3327d79 Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Fri, 28 Jan 2022 23:00:32 +0100 Subject: [PATCH 21/24] New scheme to reject root captures Reject root captures by considering unbox operations. This allows us to ignore root captures buried under type applications. --- .../src/dotty/tools/dotc/cc/CaptureOps.scala | 26 ++++++- .../src/dotty/tools/dotc/cc/CaptureSet.scala | 10 +++ .../dotty/tools/dotc/transform/Recheck.scala | 1 + .../tools/dotc/typer/CheckCaptures.scala | 77 +++++-------------- .../neg-custom-args/captures/capt-test.scala | 4 +- tests/neg-custom-args/captures/real-try.check | 26 +++++-- tests/neg-custom-args/captures/real-try.scala | 16 ++++ tests/neg-custom-args/captures/try.check | 46 +++++++---- tests/neg-custom-args/captures/try.scala | 12 +-- tests/neg-custom-args/captures/try3.scala | 4 +- tests/neg-custom-args/captures/vars.check | 27 +++---- tests/neg-custom-args/captures/vars.scala | 3 +- 12 files changed, 147 insertions(+), 105 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala index aa7efe2a04cf..23cb802356fc 100644 --- a/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala +++ b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala @@ -3,7 +3,7 @@ package dotc package cc import core.* -import Types.*, Symbols.*, Contexts.*, Annotations.* +import Types.*, Symbols.*, Contexts.*, Annotations.*, Flags.* import ast.{tpd, untpd} import Decorators.*, NameOps.* import config.Printers.capt @@ -85,3 +85,27 @@ extension (tp: Type) isImpure = true).appliedTo(args) case _ => tp + +extension (sym: Symbol) + + /** Does this symbol allow results carrying the universal capability? + * Currently this is true only for function type applies (since their + * results are unboxed) and `erasedValue` since this function is magic in + * that is allows to conjure global capabilies from nothing (aside: can we find a + * more controlled way to achieve this?). + * But it could be generalized to other functions that so that they can take capability + * classes as arguments. + */ + def allowsRootCapture(using Context): Boolean = + sym == defn.Compiletime_erasedValue + || defn.isFunctionClass(sym.maybeOwner) + + def unboxesResult(using Context): Boolean = + def containsEnclTypeParam(tp: Type): Boolean = tp.strippedDealias match + case tp @ TypeRef(pre: ThisType, _) => tp.symbol.is(Param) + case tp: TypeParamRef => true + case tp: AndOrType => containsEnclTypeParam(tp.tp1) || containsEnclTypeParam(tp.tp2) + case tp: RefinedType => containsEnclTypeParam(tp.parent) || containsEnclTypeParam(tp.refinedInfo) + case _ => false + containsEnclTypeParam(sym.info.finalResultType) + && !sym.allowsRootCapture diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala b/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala index b8daef92beef..a987b8788dd1 100644 --- a/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala +++ b/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala @@ -172,6 +172,10 @@ sealed abstract class CaptureSet extends Showable: def - (ref: CaptureRef)(using Context): CaptureSet = this -- ref.singletonCaptureSet + def disallowRootCapability(handler: () => Unit)(using Context): this.type = + if isUniversal then handler() + this + def filter(p: CaptureRef => Boolean)(using Context): CaptureSet = if this.isConst then val elems1 = elems.filter(p) @@ -276,6 +280,7 @@ object CaptureSet: var deps: Deps = emptySet def isConst = isSolved def isAlwaysEmpty = false + var addRootHandler: () => Unit = () => () private def recordElemsState()(using VarState): Boolean = varState.getElems(this) match @@ -296,6 +301,7 @@ object CaptureSet: def addNewElems(newElems: Refs, origin: CaptureSet)(using Context, VarState): CompareResult = if !isConst && recordElemsState() then elems ++= newElems + if isUniversal then addRootHandler() // assert(id != 2 || elems.size != 2, this) (CompareResult.OK /: deps) { (r, dep) => r.andAlso(dep.tryInclude(newElems, this)) @@ -312,6 +318,10 @@ object CaptureSet: else CompareResult.fail(this) + override def disallowRootCapability(handler: () => Unit)(using Context): this.type = + addRootHandler = handler + super.disallowRootCapability(handler) + private var computingApprox = false final def upperApprox(origin: CaptureSet)(using Context): CaptureSet = diff --git a/compiler/src/dotty/tools/dotc/transform/Recheck.scala b/compiler/src/dotty/tools/dotc/transform/Recheck.scala index 33de86d091ef..f9af3c38de8a 100644 --- a/compiler/src/dotty/tools/dotc/transform/Recheck.scala +++ b/compiler/src/dotty/tools/dotc/transform/Recheck.scala @@ -329,6 +329,7 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: case tree: Alternative => recheckAlternative(tree, pt) case tree: PackageDef => recheckPackageDef(tree) case tree: Thicket => defn.NothingType + case tree: Import => defn.NothingType tree match case tree: NameTree => recheckNamed(tree, pt) diff --git a/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala index c4e62e66ac75..994df9a382c4 100644 --- a/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala +++ b/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala @@ -76,16 +76,6 @@ object CheckCaptures: if remaining.accountsFor(firstRef) then report.warning(em"redundant capture: $remaining already accounts for $firstRef", ann.srcPos) - /** Does this function allow type arguments carrying the universal capability? - * Currently this is true only for `erasedValue` since this function is magic in - * that is allows to conjure global capabilies from nothing (aside: can we find a - * more controlled way to achieve this?). - * But it could be generalized to other functions that so that they can take capability - * classes as arguments. - */ - private def allowUniversalArguments(fn: Tree)(using Context): Boolean = - fn.symbol == defn.Compiletime_erasedValue - class CheckCaptures extends Recheck: thisPhase => @@ -309,6 +299,26 @@ class CheckCaptures extends Recheck: includeBoxedCaptures(res, tree.srcPos) res + override def recheckFinish(tpe: Type, tree: Tree, pt: Type)(using Context): Type = + val typeToCheck = tree match + case _: Ident | _: Select | _: Apply | _: TypeApply if tree.symbol.unboxesResult => + tpe + case _: Try => + tpe + case ValDef(_, tpt, _) if tree.symbol.is(Mutable) => + tree.symbol.info + case _ => + NoType + if typeToCheck.exists then + typeToCheck.widenDealias match + case wtp @ CapturingType(parent, refs, _) => + refs.disallowRootCapability { () => + val kind = if tree.isInstanceOf[ValDef] then "mutable variable" else "expression" + report.error(em"the $kind's type $wtp is not allowed to capture the root capability `*`", tree.srcPos) + } + case _ => + super.recheckFinish(tpe, tree, pt) + override def checkUnit(unit: CompilationUnit)(using Context): Unit = Setup(preRecheckPhase, thisPhase, recheckDef) .traverse(ctx.compilationUnit.tpdTree) @@ -319,45 +329,6 @@ class CheckCaptures extends Recheck: show(unit.tpdTree) // this dows not print tree, but makes its variables visible for dependency printing } - def checkNotGlobal(tree: Tree, tp: Type, isVar: Boolean, allArgs: Tree*)(using Context): Unit = - for ref <- tp.captureSet.elems do - val isGlobal = ref match - case ref: TermRef => ref.isRootCapability - case _ => false - if isGlobal then - val what = if ref.isRootCapability then "universal" else "global" - val notAllowed = i" is not allowed to capture the $what capability $ref" - def msg = - if allArgs.isEmpty then - i"${if isVar then "type of mutable variable" else "result type"} ${tree.knownType}$notAllowed" - else tree match - case tree: InferredTypeTree => - i"""inferred type argument ${tree.knownType}$notAllowed - | - |The inferred arguments are: [${allArgs.map(_.knownType)}%, %]""" - case _ => s"type argument$notAllowed" - report.error(msg, tree.srcPos) - - def checkNotGlobal(tree: Tree, allArgs: Tree*)(using Context): Unit = - tree match - case LambdaTypeTree(_, restpt) => - checkNotGlobal(restpt, allArgs*) - case _ => - checkNotGlobal(tree, tree.knownType, isVar = false, allArgs*) - - def checkNotGlobalDeep(tree: Tree)(using Context): Unit = - val checker = new TypeTraverser: - def traverse(tp: Type): Unit = tp match - case tp: TypeRef => - tp.info match - case TypeBounds(_, hi) => traverse(hi) - case _ => - case tp: TermRef => - case _ => - checkNotGlobal(tree, tp, isVar = true) - traverseChildren(tp) - checker.traverse(tree.knownType) - object PostCheck extends TreeTraverser: def traverse(tree: Tree)(using Context) = trace{i"post check $tree"} { tree match @@ -370,10 +341,6 @@ class CheckCaptures extends Recheck: checkWellformedPost(annot.tree) case _ => } - case tree1 @ TypeApply(fn, args) if !allowUniversalArguments(fn) => - for arg <- args do - //println(i"checking $arg in $tree: ${tree.knownType.captureSet}") - checkNotGlobal(arg, args*) case t: ValOrDefDef if t.tpt.isInstanceOf[InferredTypeTree] => val sym = t.symbol val isLocal = @@ -396,10 +363,6 @@ class CheckCaptures extends Recheck: |The type needs to be declared explicitly.""", t.srcPos) case _ => inferred.foreachPart(checkPure, StopAt.Static) - case t: ValDef if t.symbol.is(Mutable) => - checkNotGlobalDeep(t.tpt) - case t: Try => - checkNotGlobal(t) case _ => traverseChildren(tree) } diff --git a/tests/neg-custom-args/captures/capt-test.scala b/tests/neg-custom-args/captures/capt-test.scala index 0c536a280f5c..0face680a285 100644 --- a/tests/neg-custom-args/captures/capt-test.scala +++ b/tests/neg-custom-args/captures/capt-test.scala @@ -19,8 +19,8 @@ def handle[E <: Exception, R <: Top](op: (CanThrow[E]) => R)(handler: E => R): R catch case ex: E => handler(ex) def test: Unit = - val b = handle[Exception, () => Nothing] { // error + val b = handle[Exception, () => Nothing] { (x: CanThrow[Exception]) => () => raise(new Exception)(using x) - } { + } { // error (ex: Exception) => ??? } diff --git a/tests/neg-custom-args/captures/real-try.check b/tests/neg-custom-args/captures/real-try.check index 11a6fdfd50dd..95531857712e 100644 --- a/tests/neg-custom-args/captures/real-try.check +++ b/tests/neg-custom-args/captures/real-try.check @@ -1,8 +1,20 @@ --- Error: tests/neg-custom-args/captures/real-try.scala:10:2 ----------------------------------------------------------- -10 | try // error +-- Error: tests/neg-custom-args/captures/real-try.scala:12:2 ----------------------------------------------------------- +12 | try // error | ^ - | result type {*} () -> Unit is not allowed to capture the universal capability *.type -11 | () => foo(1) -12 | catch -13 | case _: Ex1 => ??? -14 | case _: Ex2 => ??? + | the expression's type {*} () -> Unit is not allowed to capture the root capability `*` +13 | () => foo(1) +14 | catch +15 | case _: Ex1 => ??? +16 | case _: Ex2 => ??? +-- Error: tests/neg-custom-args/captures/real-try.scala:18:2 ----------------------------------------------------------- +18 | try // error + | ^ + | the expression's type {*} () -> ? Cell[Unit] is not allowed to capture the root capability `*` +19 | () => Cell(foo(1)) +20 | catch +21 | case _: Ex1 => ??? +22 | case _: Ex2 => ??? +-- Error: tests/neg-custom-args/captures/real-try.scala:30:4 ----------------------------------------------------------- +30 | b.x // error + | ^^^ + | the expression's type box {*} () -> Unit is not allowed to capture the root capability `*` diff --git a/tests/neg-custom-args/captures/real-try.scala b/tests/neg-custom-args/captures/real-try.scala index 9a8ccd694dc9..94e1eafd9af2 100644 --- a/tests/neg-custom-args/captures/real-try.scala +++ b/tests/neg-custom-args/captures/real-try.scala @@ -6,9 +6,25 @@ class Ex2 extends Exception("Ex2") def foo(i: Int): (CanThrow[Ex1], CanThrow[Ex2]) ?-> Unit = if i > 0 then throw new Ex1 else throw new Ex2 +class Cell[+T](val x: T) + def test() = try // error () => foo(1) catch case _: Ex1 => ??? case _: Ex2 => ??? + + try // error + () => Cell(foo(1)) + catch + case _: Ex1 => ??? + case _: Ex2 => ??? + + val b = try // ok here, but error on use + Cell(() => foo(1))//: Cell[box {ev} () => Unit] <: Cell[box {*} () => Unit] + catch + case _: Ex1 => ??? + case _: Ex2 => ??? + + b.x // error diff --git a/tests/neg-custom-args/captures/try.check b/tests/neg-custom-args/captures/try.check index a2fe96016b80..7dbccc469089 100644 --- a/tests/neg-custom-args/captures/try.check +++ b/tests/neg-custom-args/captures/try.check @@ -1,3 +1,11 @@ +-- Error: tests/neg-custom-args/captures/try.scala:24:3 ---------------------------------------------------------------- +22 | val a = handle[Exception, CanThrow[Exception]] { +23 | (x: CanThrow[Exception]) => x +24 | }{ // error + | ^ + | the expression's type {*} CT[Exception] is not allowed to capture the root capability `*` +25 | (ex: Exception) => ??? +26 | } -- [E007] Type Mismatch Error: tests/neg-custom-args/captures/try.scala:28:43 ------------------------------------------ 28 | val b = handle[Exception, () -> Nothing] { // error | ^ @@ -7,19 +15,25 @@ 30 | } { longer explanation available when compiling with `-explain` --- Error: tests/neg-custom-args/captures/try.scala:22:28 --------------------------------------------------------------- -22 | val a = handle[Exception, CanThrow[Exception]] { // error - | ^^^^^^^^^^^^^^^^^^^ - | type argument is not allowed to capture the universal capability (* : Any) --- Error: tests/neg-custom-args/captures/try.scala:34:11 --------------------------------------------------------------- -34 | val xx = handle { // error - | ^^^^^^ - | inferred type argument {*} () -> Int is not allowed to capture the universal capability (* : Any) - | - | The inferred arguments are: [? Exception, {*} () -> Int] --- Error: tests/neg-custom-args/captures/try.scala:46:13 --------------------------------------------------------------- -46 |val global = handle { // error - | ^^^^^^ - | inferred type argument {*} () -> Int is not allowed to capture the universal capability (* : Any) - | - | The inferred arguments are: [? Exception, {*} () -> Int] +-- Error: tests/neg-custom-args/captures/try.scala:39:4 ---------------------------------------------------------------- +34 | val xx = handle { +35 | (x: CanThrow[Exception]) => +36 | () => +37 | raise(new Exception)(using x) +38 | 22 +39 | } { // error + | ^ + | the expression's type {*} () -> Int is not allowed to capture the root capability `*` +40 | (ex: Exception) => () => 22 +41 | } +-- Error: tests/neg-custom-args/captures/try.scala:51:2 ---------------------------------------------------------------- +46 |val global = handle { +47 | (x: CanThrow[Exception]) => +48 | () => +49 | raise(new Exception)(using x) +50 | 22 +51 |} { // error + | ^ + | the expression's type {*} () -> Int is not allowed to capture the root capability `*` +52 | (ex: Exception) => () => 22 +53 |} diff --git a/tests/neg-custom-args/captures/try.scala b/tests/neg-custom-args/captures/try.scala index b128f82a2a3c..c76da6641780 100644 --- a/tests/neg-custom-args/captures/try.scala +++ b/tests/neg-custom-args/captures/try.scala @@ -19,9 +19,9 @@ def handle[E <: Exception, R <: Top](op: CanThrow[E] => R)(handler: E => R): R = catch case ex: E => handler(ex) def test = - val a = handle[Exception, CanThrow[Exception]] { // error + val a = handle[Exception, CanThrow[Exception]] { (x: CanThrow[Exception]) => x - }{ + }{ // error (ex: Exception) => ??? } @@ -31,23 +31,23 @@ def test = (ex: Exception) => ??? } - val xx = handle { // error + val xx = handle { (x: CanThrow[Exception]) => () => raise(new Exception)(using x) 22 - } { + } { // error (ex: Exception) => () => 22 } val yy = xx :: Nil yy // OK -val global = handle { // error +val global = handle { (x: CanThrow[Exception]) => () => raise(new Exception)(using x) 22 -} { +} { // error (ex: Exception) => () => 22 } \ No newline at end of file diff --git a/tests/neg-custom-args/captures/try3.scala b/tests/neg-custom-args/captures/try3.scala index 4fbb980b9e03..8c5bc18bf3be 100644 --- a/tests/neg-custom-args/captures/try3.scala +++ b/tests/neg-custom-args/captures/try3.scala @@ -14,12 +14,12 @@ def raise[E <: Exception](ex: E)(using CanThrow[E]): Nothing = @main def Test: Int = def f(a: Boolean) = - handle { // error + handle { if !a then raise(IOException()) (b: Boolean) => if !b then raise(IOException()) 0 - } { + } { // error ex => (b: Boolean) => -1 } val g = f(true) diff --git a/tests/neg-custom-args/captures/vars.check b/tests/neg-custom-args/captures/vars.check index 0df38b918862..6a036e49ede2 100644 --- a/tests/neg-custom-args/captures/vars.check +++ b/tests/neg-custom-args/captures/vars.check @@ -5,17 +5,18 @@ | Required: () -> Unit longer explanation available when compiling with `-explain` --- Error: tests/neg-custom-args/captures/vars.scala:13:16 -------------------------------------------------------------- +-- Error: tests/neg-custom-args/captures/vars.scala:13:6 --------------------------------------------------------------- 13 | var a: String => String = f // error - | ^^^^^^^^^^^^^^^^ - | type of mutable variable String => String is not allowed to capture the universal capability (* : Any) --- Error: tests/neg-custom-args/captures/vars.scala:14:9 --------------------------------------------------------------- -14 | var b: List[String => String] = Nil // error - | ^^^^^^^^^^^^^^^^^^^^^^ - | type of mutable variable List[String => String] is not allowed to capture the universal capability (* : Any) --- Error: tests/neg-custom-args/captures/vars.scala:29:2 --------------------------------------------------------------- -29 | local { cap3 => // error - | ^^^^^ - |inferred type argument {*} (x$0: ? String) -> ? String is not allowed to capture the universal capability (* : Any) - | - |The inferred arguments are: [{*} (x$0: ? String) -> ? String] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | the mutable variable's type {*} String -> String is not allowed to capture the root capability `*` +-- Error: tests/neg-custom-args/captures/vars.scala:15:4 --------------------------------------------------------------- +15 | b.head // error + | ^^^^^^ + | the expression's type {*} String -> String is not allowed to capture the root capability `*` +-- Error: tests/neg-custom-args/captures/vars.scala:30:8 --------------------------------------------------------------- +30 | local { cap3 => // error + | ^ + | the expression's type {*} (x$0: ? String) -> ? String is not allowed to capture the root capability `*` +31 | def g(x: String): String = if cap3 == cap3 then "" else "a" +32 | g +33 | } diff --git a/tests/neg-custom-args/captures/vars.scala b/tests/neg-custom-args/captures/vars.scala index e85bcbe2db04..5e413b7ea3fb 100644 --- a/tests/neg-custom-args/captures/vars.scala +++ b/tests/neg-custom-args/captures/vars.scala @@ -11,7 +11,8 @@ def test(cap1: Cap, cap2: Cap) = val z2c: () -> Unit = z2 // error var a: String => String = f // error - var b: List[String => String] = Nil // error + var b: List[String => String] = Nil // was error, now OK + b.head // error def scope = val cap3: Cap = CC() From 4cc8ff0af8631acf56f894db7ffe079e4f7269d6 Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Sun, 30 Jan 2022 18:44:48 +0100 Subject: [PATCH 22/24] Three changes to typing rules The following two rules replace #13657: 1. Exploit capture monotonicity in the apply rule, as discussed in #14387. 2. A rule to make typing nested classes more flexible as discussed in #14390. There's also a bug fix where we now enforce a previously missing subcapturing relationship between the capture set of parent of a class and the capture set of the class itself. Clearly a class captures all variables captured by one of its parent classes. --- .../src/dotty/tools/dotc/cc/CaptureSet.scala | 2 +- .../dotty/tools/dotc/transform/Recheck.scala | 38 ++++----- .../tools/dotc/typer/CheckCaptures.scala | 82 +++++++++++++++---- tests/neg-custom-args/captures/lazylist.check | 4 +- .../neg-custom-args/captures/lazylists1.check | 10 +-- .../neg-custom-args/captures/lazylists1.scala | 2 +- .../neg-custom-args/captures/lazylists2.check | 32 +++++--- .../neg-custom-args/captures/lazylists2.scala | 8 +- .../captures/lazylists-exceptions.scala | 68 +++++++++++++++ .../pos-custom-args/captures/lazylists.scala | 1 - .../pos-custom-args/captures/lazylists1.scala | 17 ++-- 11 files changed, 198 insertions(+), 66 deletions(-) create mode 100644 tests/pos-custom-args/captures/lazylists-exceptions.scala diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala b/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala index a987b8788dd1..bf1306024e5f 100644 --- a/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala +++ b/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala @@ -104,7 +104,7 @@ sealed abstract class CaptureSet extends Showable: extension (x: CaptureRef) private def subsumes(y: CaptureRef) = (x eq y) || y.match - case y: TermRef => y.prefix eq x // ^^^ y.prefix.subsumes(x) ? + case y: TermRef => y.prefix eq x case _ => false /** {x} <:< this where <:< is subcapturing, but treating all variables diff --git a/compiler/src/dotty/tools/dotc/transform/Recheck.scala b/compiler/src/dotty/tools/dotc/transform/Recheck.scala index f9af3c38de8a..d100bdd36f21 100644 --- a/compiler/src/dotty/tools/dotc/transform/Recheck.scala +++ b/compiler/src/dotty/tools/dotc/transform/Recheck.scala @@ -126,15 +126,12 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: bindType def recheckValDef(tree: ValDef, sym: Symbol)(using Context): Unit = - if !tree.rhs.isEmpty then recheckRHS(tree.rhs, sym.info, sym) + if !tree.rhs.isEmpty then recheck(tree.rhs, sym.info) def recheckDefDef(tree: DefDef, sym: Symbol)(using Context): Unit = val rhsCtx = linkConstructorParams(sym).withOwner(sym) if !tree.rhs.isEmpty && !sym.isInlineMethod && !sym.isEffectivelyErased then - inContext(rhsCtx) { recheckRHS(tree.rhs, recheck(tree.tpt), sym) } - - def recheckRHS(tree: Tree, pt: Type, sym: Symbol)(using Context): Type = - recheck(tree, pt) + inContext(rhsCtx) { recheck(tree.rhs, recheck(tree.tpt)) } def recheckTypeDef(tree: TypeDef, sym: Symbol)(using Context): Type = recheck(tree.rhs) @@ -358,21 +355,22 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: // Don't report closure nodes, since their span is a point; wait instead // for enclosing block to preduce an error case _ => - val actual = tpe.widenExpr - val expected = pt.widenExpr - //println(i"check conforms $actual <:< $expected") - val isCompatible = - actual <:< expected - || expected.isRepeatedParam - && actual <:< expected.translateFromRepeated(toArray = tree.tpe.isRef(defn.ArrayClass)) - if !isCompatible then - recheckr.println(i"conforms failed for ${tree}: $tpe vs $expected") - err.typeMismatch(tree.withType(tpe), expected) - else if debugSuccesses then - tree match - case _: Ident => - println(i"SUCCESS $tree:\n${TypeComparer.explained(_.isSubType(actual, expected))}") - case _ => + checkConformsExpr(tpe, tpe.widenExpr, pt.widenExpr, tree) + + def checkConformsExpr(original: Type, actual: Type, expected: Type, tree: Tree)(using Context): Unit = + //println(i"check conforms $actual <:< $expected") + val isCompatible = + actual <:< expected + || expected.isRepeatedParam + && actual <:< expected.translateFromRepeated(toArray = tree.tpe.isRef(defn.ArrayClass)) + if !isCompatible then + recheckr.println(i"conforms failed for ${tree}: $original vs $expected") + err.typeMismatch(tree.withType(original), expected) + else if debugSuccesses then + tree match + case _: Ident => + println(i"SUCCESS $tree:\n${TypeComparer.explained(_.isSubType(actual, expected))}") + case _ => def checkUnit(unit: CompilationUnit)(using Context): Unit = recheck(unit.tpdTree) diff --git a/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala index 994df9a382c4..2e81576eb317 100644 --- a/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala +++ b/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala @@ -16,6 +16,7 @@ import transform.Recheck import Recheck.* import scala.collection.mutable import CaptureSet.withCaptureSetsExplained +import StdNames.nme import reporting.trace object CheckCaptures: @@ -213,22 +214,6 @@ class CheckCaptures extends Recheck: interpolateVarsIn(tree.tpt) curEnv = saved - override def recheckRHS(tree: Tree, pt: Type, sym: Symbol)(using Context): Type = - val pt1 = pt match - case CapturingType(core, refs, _) - if sym.owner.isClass && !sym.owner.isExtensibleClass - && refs.elems.contains(sym.owner.thisType) => - val paramCaptures = - sym.paramSymss.flatten.foldLeft(CaptureSet.empty) { (cs, p) => - val pcs = p.info.captureSet - (cs ++ (if pcs.isConst then pcs else CaptureSet.universal)).asConst - } - val declaredCaptures = sym.owner.asClass.givenSelfType.captureSet - pt.derivedCapturingType(core, refs ++ (declaredCaptures -- paramCaptures)) - case _ => - pt - recheck(tree, pt1) - override def recheckClassDef(tree: TypeDef, impl: Template, cls: ClassSymbol)(using Context): Type = for param <- cls.paramGetters do if param.is(Private) && !param.info.captureSet.isAlwaysEmpty then @@ -237,6 +222,8 @@ class CheckCaptures extends Recheck: param.srcPos) val saved = curEnv val localSet = capturedVars(cls) + for parent <- impl.parents do + checkSubset(capturedVars(parent.tpe.classSymbol), localSet, parent.srcPos) if !localSet.isAlwaysEmpty then curEnv = Env(cls, localSet, false, curEnv) try super.recheckClassDef(tree, impl, cls) finally curEnv = saved @@ -289,9 +276,34 @@ class CheckCaptures extends Recheck: finally curEnv = curEnv.outer recheckFinish(result, arg, pt) + /** A specialized implementation of the apply rule from https://github.com/lampepfl/dotty/discussions/14387: + * + * E |- f: Cf (Ra -> Cr Rr) + * E |- a: Ra + * ------------------------ + * E |- f a: Cr /\ {f} Rr + * + * Specialized for the case where `f` is a tracked and the arguments are pure. + * This replaces the previous rule #13657 while still allowing the code in pos/lazylists1.scala. + * We could consider generalizing to the case where the function arguments have non-empty + * capture sets as suggested in #14387, but that would make capture set computations more complex, + * so we should also evaluate the performance impact. + */ override def recheckApply(tree: Apply, pt: Type)(using Context): Type = includeCallCaptures(tree.symbol, tree.srcPos) - super.recheckApply(tree, pt) + super.recheckApply(tree, pt) match + case tp @ CapturingType(tp1, refs, kind) => + tree.fun match + case Select(qual, nme.apply) + if defn.isFunctionType(qual.tpe.widen) => + qual.tpe match + case ref: CaptureRef + if ref.isTracked && tree.args.forall(_.tpe.captureSet.isAlwaysEmpty) => + tp.derivedCapturingType(tp1, refs ** ref.singletonCaptureSet) + .showing(i"narrow $tree: $tp --> $result", capt) + case _ => tp + case _ => tp + case tp => tp override def recheck(tree: Tree, pt: Type = WildcardType)(using Context): Type = val res = super.recheck(tree, pt) @@ -319,6 +331,42 @@ class CheckCaptures extends Recheck: case _ => super.recheckFinish(tpe, tree, pt) + /** This method implements the rule outlined in #14390: + * When checking an expression `e: Ca Ta` against an expected type `Cx Tx` + * where the capture set of `Cx` contains this and any method inside the class + * `Cls` of `this` that contains `e` has only pure parameters, drop from `Ca` + * all references to variables or this references outside `Cls`. These are all + * accessed through this, so are already accounted for by `Cx`. + */ + override def checkConformsExpr(original: Type, actual: Type, expected: Type, tree: Tree)(using Context): Unit = + def isPure(info: Type): Boolean = info match + case info: PolyType => isPure(info.resType) + case info: MethodType => info.paramInfos.forall(_.captureSet.isAlwaysEmpty) && isPure(info.resType) + case _ => true + def isPureContext(owner: Symbol, limit: Symbol): Boolean = + if owner == limit then true + else if !owner.exists then false + else isPure(owner.info) && isPureContext(owner.owner, limit) + val actual1 = (expected, actual.widen) match + case (CapturingType(ecore, erefs, _), actualw @ CapturingType(acore, arefs, _)) => + val arefs1 = (arefs /: erefs.elems) { (arefs1, eref) => + eref match + case eref: ThisType if isPureContext(ctx.owner, eref.cls) => + arefs1.filter { + case aref1: TermRef => !eref.cls.isContainedIn(aref1.symbol.owner) + case aref1: ThisType => !eref.cls.isContainedIn(aref1.cls) + case _ => true + } + case _ => + arefs1 + } + if arefs1 eq arefs then actual + else actualw.derivedCapturingType(acore, arefs1) + .showing(i"healing $actual --> $result", capt) + case _ => + actual + super.checkConformsExpr(original, actual1, expected, tree) + override def checkUnit(unit: CompilationUnit)(using Context): Unit = Setup(preRecheckPhase, thisPhase, recheckDef) .traverse(ctx.compilationUnit.tpdTree) diff --git a/tests/neg-custom-args/captures/lazylist.check b/tests/neg-custom-args/captures/lazylist.check index 31624e437928..bdbef10de0d6 100644 --- a/tests/neg-custom-args/captures/lazylist.check +++ b/tests/neg-custom-args/captures/lazylist.check @@ -37,6 +37,6 @@ longer explanation available when compiling with `-explain` 17 | def tail = xs() // error: cannot have an inferred type | ^^^^^^^^^^^^^^^ | Non-local method tail cannot have an inferred result type - | {*} lazylists.LazyList[T] - | with non-empty capture set {*}. + | {LazyCons.this.xs} lazylists.LazyList[T] + | with non-empty capture set {LazyCons.this.xs}. | The type needs to be declared explicitly. diff --git a/tests/neg-custom-args/captures/lazylists1.check b/tests/neg-custom-args/captures/lazylists1.check index 29291c8044c0..1d23de2ff134 100644 --- a/tests/neg-custom-args/captures/lazylists1.check +++ b/tests/neg-custom-args/captures/lazylists1.check @@ -1,7 +1,7 @@ --- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylists1.scala:25:63 ----------------------------------- -25 | def concat(other: {f} LazyList[A]): {this} LazyList[A] = ??? : ({xs, f} LazyList[A]) // error - | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ - | Found: {xs, f} LazyList[A] - | Required: {Mapped.this, xs} LazyList[A] +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylists1.scala:25:66 ----------------------------------- +25 | def concat(other: {f} LazyList[A]): {this, f} LazyList[A] = ??? : ({xs, f} LazyList[A]) // error + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | Found: {xs, f} LazyList[A] + | Required: {Mapped.this, f} LazyList[A] longer explanation available when compiling with `-explain` diff --git a/tests/neg-custom-args/captures/lazylists1.scala b/tests/neg-custom-args/captures/lazylists1.scala index 4091ee2c62ae..c6475223b783 100644 --- a/tests/neg-custom-args/captures/lazylists1.scala +++ b/tests/neg-custom-args/captures/lazylists1.scala @@ -22,6 +22,6 @@ extension [A](xs: {*} LazyList[A]) def head: B = f(xs.head) def tail: {this} LazyList[B] = xs.tail.map(f) // OK def drop(n: Int): {this} LazyList[B] = ??? : ({xs, f} LazyList[B]) // OK - def concat(other: {f} LazyList[A]): {this} LazyList[A] = ??? : ({xs, f} LazyList[A]) // error + def concat(other: {f} LazyList[A]): {this, f} LazyList[A] = ??? : ({xs, f} LazyList[A]) // error new Mapped diff --git a/tests/neg-custom-args/captures/lazylists2.check b/tests/neg-custom-args/captures/lazylists2.check index a8f145ecd9d7..8cb825f02537 100644 --- a/tests/neg-custom-args/captures/lazylists2.check +++ b/tests/neg-custom-args/captures/lazylists2.check @@ -29,17 +29,29 @@ longer explanation available when compiling with `-explain` 32 | new Mapped longer explanation available when compiling with `-explain` --- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylists2.scala:41:48 ----------------------------------- -41 | def tail: {this} LazyList[B] = xs.tail.map(f) // error - | ^^^^^^^^^^^^^^ - | Found: {f} LazyList[B] - | Required: {xs} LazyList[B] +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylists2.scala:36:4 ------------------------------------ +36 | final class Mapped extends LazyList[B]: // error + | ^ + | Found: {f, xs} LazyList[B] + | Required: {xs} LazyList[B] +37 | this: ({xs} Mapped) => +38 | def isEmpty = false +39 | def head: B = f(xs.head) +40 | def tail: {this} LazyList[B] = xs.tail.map(f) +41 | new Mapped longer explanation available when compiling with `-explain` --- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylists2.scala:59:48 ----------------------------------- -59 | def tail: {this} LazyList[B] = xs.tail.map(f) // error - | ^^^^^^^^^^^^^^ - | Found: {f} LazyList[B] - | Required: {Mapped.this} LazyList[B] +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylists2.scala:54:4 ------------------------------------ +54 | class Mapped extends LazyList[B]: // error + | ^ + | Found: {f, xs} LazyList[B] + | Required: LazyList[B] +55 | this: ({xs, f} Mapped) => +56 | def isEmpty = false +57 | def head: B = f(xs.head) +58 | def tail: {this} LazyList[B] = xs.tail.map(f) +59 | class Mapped2 extends Mapped: +60 | this: Mapped => +61 | new Mapped2 longer explanation available when compiling with `-explain` diff --git a/tests/neg-custom-args/captures/lazylists2.scala b/tests/neg-custom-args/captures/lazylists2.scala index b9ebb0a7a9f0..fc1ab768047a 100644 --- a/tests/neg-custom-args/captures/lazylists2.scala +++ b/tests/neg-custom-args/captures/lazylists2.scala @@ -33,12 +33,12 @@ extension [A](xs: {*} LazyList[A]) new Mapped def map3[B](f: A => B): {xs} LazyList[B] = - final class Mapped extends LazyList[B]: + final class Mapped extends LazyList[B]: // error this: ({xs} Mapped) => def isEmpty = false def head: B = f(xs.head) - def tail: {this} LazyList[B] = xs.tail.map(f) // error + def tail: {this} LazyList[B] = xs.tail.map(f) new Mapped def map4[B](f: A => B): {xs} LazyList[B] = @@ -51,12 +51,12 @@ extension [A](xs: {*} LazyList[A]) new Mapped def map5[B](f: A => B): LazyList[B] = - class Mapped extends LazyList[B]: + class Mapped extends LazyList[B]: // error this: ({xs, f} Mapped) => def isEmpty = false def head: B = f(xs.head) - def tail: {this} LazyList[B] = xs.tail.map(f) // error + def tail: {this} LazyList[B] = xs.tail.map(f) class Mapped2 extends Mapped: this: Mapped => new Mapped2 diff --git a/tests/pos-custom-args/captures/lazylists-exceptions.scala b/tests/pos-custom-args/captures/lazylists-exceptions.scala new file mode 100644 index 000000000000..6b2378906b8e --- /dev/null +++ b/tests/pos-custom-args/captures/lazylists-exceptions.scala @@ -0,0 +1,68 @@ +import language.experimental.saferExceptions +import annotation.unchecked.uncheckedVariance + +trait LazyList[+A]: + this: {*} LazyList[A] => + + def isEmpty: Boolean + def head: A + def tail: {this} LazyList[A] + +object LazyNil extends LazyList[Nothing]: + def isEmpty: Boolean = true + def head = ??? + def tail = ??? + +final class LazyCons[+T](val x: T, val xs: () => {*} LazyList[T]) extends LazyList[T]: + this: {*} LazyList[T] => + + var forced = false + var cache: {this} LazyList[T @uncheckedVariance] = compiletime.uninitialized + + private def force = + if !forced then + cache = xs() + forced = true + cache + + def isEmpty = false + def head = x + def tail: {this} LazyList[T] = force +end LazyCons + +extension [A](xs: {*} LazyList[A]) + def map[B](f: A => B): {xs, f} LazyList[B] = + if xs.isEmpty then LazyNil + else LazyCons(f(xs.head), () => xs.tail.map(f)) + + def filter(p: A => Boolean): {xs, p} LazyList[A] = + if xs.isEmpty then LazyNil + else if p(xs.head) then LazyCons(xs.head, () => xs.tail.filter(p)) + else xs.tail.filter(p) + + def concat(ys: {*} LazyList[A]): {xs, ys} LazyList[A] = + if xs.isEmpty then ys + else LazyCons(xs.head, () => xs.tail.concat(ys)) +end extension + +class Ex1 extends Exception +class Ex2 extends Exception + +def test(using cap1: CanThrow[Ex1], cap2: CanThrow[Ex2]) = + val xs = LazyCons(1, () => LazyNil) + + def f(x: Int): Int throws Ex1 = + if x < 0 then throw Ex1() + x * x + + def g(x: Int): Int throws Ex1 = + if x < 0 then throw Ex1() + x * x + + def x1 = xs.map(f) + def x1c: {cap1} LazyList[Int] = x1 + + def x2 = x1.concat(xs.map(g).filter(_ > 0)) + def x2c: {cap1, cap2} LazyList[Int] = x2 + + diff --git a/tests/pos-custom-args/captures/lazylists.scala b/tests/pos-custom-args/captures/lazylists.scala index c566bea8dd64..fd130c87cdea 100644 --- a/tests/pos-custom-args/captures/lazylists.scala +++ b/tests/pos-custom-args/captures/lazylists.scala @@ -21,7 +21,6 @@ extension [A](xs: {*} LazyList[A]) def isEmpty = false def head: B = f(xs.head) def tail: {this} LazyList[B] = xs.tail.map(f) // OK - def concat(other: {f} LazyList[A]): {this, f} LazyList[A] = ??? : {xs, f} LazyList[A] // OK if xs.isEmpty then LazyNil else new Mapped diff --git a/tests/pos-custom-args/captures/lazylists1.scala b/tests/pos-custom-args/captures/lazylists1.scala index 2dbb5ac232e2..c203450499e7 100644 --- a/tests/pos-custom-args/captures/lazylists1.scala +++ b/tests/pos-custom-args/captures/lazylists1.scala @@ -7,29 +7,36 @@ trait LazyList[+A]: def isEmpty: Boolean def head: A def tail: {this} LazyList[A] + def concat[B >: A](other: {*} LazyList[B]): {this, other} LazyList[B] object LazyNil extends LazyList[Nothing]: def isEmpty: Boolean = true def head = ??? def tail = ??? + def concat[B](other: {*} LazyList[B]): {other} LazyList[B] = other -final class LazyCons[+T](val x: T, val xs: Int => {*} LazyList[T]) extends LazyList[T]: - this: {*} LazyList[T] => +final class LazyCons[+A](val x: A, val xs: () => {*} LazyList[A]) extends LazyList[A]: + this: {*} LazyList[A] => def isEmpty = false def head = x - def tail: {this} LazyList[T] = xs(0) + def tail: {this} LazyList[A] = xs() + def concat[B >: A](other: {*} LazyList[B]): {this, other} LazyList[B] = + LazyCons(head, () => tail.concat(other)) extension [A](xs: {*} LazyList[A]) def map[B](f: A => B): {xs, f} LazyList[B] = if xs.isEmpty then LazyNil - else LazyCons(f(xs.head), x => xs.tail.map(f)) + else LazyCons(f(xs.head), () => xs.tail.map(f)) def test(cap1: Cap, cap2: Cap) = def f(x: String): String = if cap1 == cap1 then "" else "a" def g(x: String): String = if cap2 == cap2 then "" else "a" - val xs = new LazyCons("", x => if f("") == f("") then LazyNil else LazyNil) + val xs = new LazyCons("", () => if f("") == f("") then LazyNil else LazyNil) val xsc: {cap1} LazyList[String] = xs val ys = xs.map(g) val ysc: {cap1, cap2} LazyList[String] = ys + val zs = new LazyCons("", () => if g("") == g("") then LazyNil else LazyNil) + val as = xs.concat(zs) + val asc: {xs, zs} LazyList[String] = as From fedd8f2b147ec223628098a19b5e9cb90819c6c7 Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Tue, 1 Feb 2022 11:14:42 +0100 Subject: [PATCH 23/24] Change result type of by-name closures Make it the formal type rather than the actual one. This avoids messing up capture annotations. --- .../tools/dotc/transform/ElimByName.scala | 9 +++------ tests/neg-custom-args/captures/byname.check | 18 +++++++++--------- tests/neg-custom-args/captures/byname.scala | 5 ++++- .../captures/lazylists-exceptions.scala | 13 ++++++++++--- 4 files changed, 26 insertions(+), 19 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/transform/ElimByName.scala b/compiler/src/dotty/tools/dotc/transform/ElimByName.scala index b3d5ab1da4b4..38ef9428273b 100644 --- a/compiler/src/dotty/tools/dotc/transform/ElimByName.scala +++ b/compiler/src/dotty/tools/dotc/transform/ElimByName.scala @@ -92,6 +92,7 @@ class ElimByName extends MiniPhase, InfoTransformer: sym.is(Method) || exprBecomesFunction(sym) def byNameClosure(arg: Tree, argType: Type)(using Context): Tree = + report.log(i"creating by name closure for $argType") val meth = newAnonFun(ctx.owner, MethodType(Nil, argType), coord = arg.span) Closure(meth, _ => arg.changeOwnerAfter(ctx.owner, meth, thisPhase), @@ -135,12 +136,8 @@ class ElimByName extends MiniPhase, InfoTransformer: if isByNameRef(qual) && (isPureExpr(qual) || qual.symbol.isAllOf(InlineParam)) => qual case _ => - if isByNameRef(arg) || arg.symbol.name.is(SuperArgName) - then arg - else - var argType = arg.tpe.widenIfUnstable - if argType.isBottomType then argType = formalResult - byNameClosure(arg, argType) + if isByNameRef(arg) || arg.symbol.name.is(SuperArgName) then arg + else byNameClosure(arg, formalResult) case _ => arg diff --git a/tests/neg-custom-args/captures/byname.check b/tests/neg-custom-args/captures/byname.check index d8d5d689efae..3321da3c17db 100644 --- a/tests/neg-custom-args/captures/byname.check +++ b/tests/neg-custom-args/captures/byname.check @@ -1,18 +1,18 @@ --- Warning: tests/neg-custom-args/captures/byname.scala:14:18 ---------------------------------------------------------- -14 | def h(x: {cap1} -> I) = x // warning +-- Warning: tests/neg-custom-args/captures/byname.scala:17:18 ---------------------------------------------------------- +17 | def h(x: {cap1} -> I) = x // warning | ^ | Style: by-name `->` should immediately follow closing `}` of capture set | to avoid confusion with function type. | That is, `{c}-> T` instead of `{c} -> T`. --- [E007] Type Mismatch Error: tests/neg-custom-args/captures/byname.scala:7:5 ----------------------------------------- -7 | h(f()) // error - | ^^^ - | Found: {cap2} (x$0: Int) -> Int - | Required: Int -> Int +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/byname.scala:10:6 ---------------------------------------- +10 | h(f2()) // error + | ^^^^ + | Found: {cap1} (x$0: Int) -> Int + | Required: {cap2} Int -> Int longer explanation available when compiling with `-explain` --- [E007] Type Mismatch Error: tests/neg-custom-args/captures/byname.scala:16:5 ---------------------------------------- -16 | h(g()) // error +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/byname.scala:19:5 ---------------------------------------- +19 | h(g()) // error | ^^^ | Found: {cap2} () ?-> I | Required: {cap1} () ?-> I diff --git a/tests/neg-custom-args/captures/byname.scala b/tests/neg-custom-args/captures/byname.scala index 0afca7e261e2..1838647f2899 100644 --- a/tests/neg-custom-args/captures/byname.scala +++ b/tests/neg-custom-args/captures/byname.scala @@ -3,8 +3,11 @@ def test(cap1: Cap, cap2: Cap) = def f() = if cap1 == cap1 then g else g def g(x: Int) = if cap2 == cap2 then 1 else x + def g2(x: Int) = if cap1 == cap1 then 1 else x + def f2() = if cap1 == cap1 then g2 else g2 def h(ff: => {cap2} Int -> Int) = ff - h(f()) // error + h(f()) // ok + h(f2()) // error class I diff --git a/tests/pos-custom-args/captures/lazylists-exceptions.scala b/tests/pos-custom-args/captures/lazylists-exceptions.scala index 6b2378906b8e..b9b303118358 100644 --- a/tests/pos-custom-args/captures/lazylists-exceptions.scala +++ b/tests/pos-custom-args/captures/lazylists-exceptions.scala @@ -37,19 +37,26 @@ extension [A](xs: {*} LazyList[A]) def filter(p: A => Boolean): {xs, p} LazyList[A] = if xs.isEmpty then LazyNil - else if p(xs.head) then LazyCons(xs.head, () => xs.tail.filter(p)) + else if p(xs.head) then lazyCons(xs.head, xs.tail.filter(p)) else xs.tail.filter(p) def concat(ys: {*} LazyList[A]): {xs, ys} LazyList[A] = if xs.isEmpty then ys - else LazyCons(xs.head, () => xs.tail.concat(ys)) + else xs.head #: xs.tail.concat(ys) end extension +extension [A](x: A) + def #:(xs1: => {*} LazyList[A]): {xs1} LazyList[A] = + LazyCons(x, () => xs1) + +def lazyCons[A](x: A, xs1: => {*} LazyList[A]): {xs1} LazyList[A] = + LazyCons(x, () => xs1) + class Ex1 extends Exception class Ex2 extends Exception def test(using cap1: CanThrow[Ex1], cap2: CanThrow[Ex2]) = - val xs = LazyCons(1, () => LazyNil) + val xs = 1 #: LazyNil def f(x: Int): Int throws Ex1 = if x < 0 then throw Ex1() From 0ddbbe820977e03407f1d2ab4352b39cf2307033 Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Tue, 1 Feb 2022 14:58:19 +0100 Subject: [PATCH 24/24] Capture-check exception capabilties Required exeption capability references are now preserved beyond typer in type ascriptions, so that they can be checked for escapes. --- .../dotty/tools/dotc/core/Definitions.scala | 1 + .../tools/dotc/typer/CheckCaptures.scala | 10 +++++ .../src/dotty/tools/dotc/typer/Checking.scala | 6 ++- .../src/dotty/tools/dotc/typer/Typer.scala | 13 ++++-- .../scala/internal/requiresCapability.scala | 8 ++++ .../captures/lazylists-exceptions.check | 9 ++++ .../captures/lazylists-exceptions.scala | 45 +++++++++++++++++++ tests/pos-custom-args/captures/i13816.scala | 4 +- .../captures/lazylists-exceptions.scala | 11 +++++ 9 files changed, 99 insertions(+), 8 deletions(-) create mode 100644 library/src-bootstrapped/scala/internal/requiresCapability.scala create mode 100644 tests/neg-custom-args/captures/lazylists-exceptions.check create mode 100644 tests/neg-custom-args/captures/lazylists-exceptions.scala diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 974fd3f2ab0d..ab0c28ffde02 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -968,6 +968,7 @@ class Definitions { @tu lazy val TargetNameAnnot: ClassSymbol = requiredClass("scala.annotation.targetName") @tu lazy val VarargsAnnot: ClassSymbol = requiredClass("scala.annotation.varargs") @tu lazy val SinceAnnot: ClassSymbol = requiredClass("scala.annotation.since") + @tu lazy val RequiresCapabilityAnnot: ClassSymbol = requiredClass("scala.annotation.internal.requiresCapability") @tu lazy val RetainsAnnot: ClassSymbol = requiredClass("scala.retains") @tu lazy val RetainsByNameAnnot: ClassSymbol = requiredClass("scala.retainsByName") diff --git a/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala index 2e81576eb317..46ce6d037ff3 100644 --- a/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala +++ b/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala @@ -305,6 +305,16 @@ class CheckCaptures extends Recheck: case _ => tp case tp => tp + override def recheckTyped(tree: Typed)(using Context): Type = + tree.tpt.tpe match + case AnnotatedType(_, annot) if annot.symbol == defn.RequiresCapabilityAnnot => + annot.tree match + case Apply(_, cap :: Nil) => + markFree(cap.symbol, tree.srcPos) + case _ => + case _ => + super.recheckTyped(tree) + override def recheck(tree: Tree, pt: Type = WildcardType)(using Context): Type = val res = super.recheck(tree, pt) if tree.isTerm then diff --git a/compiler/src/dotty/tools/dotc/typer/Checking.scala b/compiler/src/dotty/tools/dotc/typer/Checking.scala index c46c6d7c06cd..e994b7ef434f 100644 --- a/compiler/src/dotty/tools/dotc/typer/Checking.scala +++ b/compiler/src/dotty/tools/dotc/typer/Checking.scala @@ -1379,9 +1379,11 @@ trait Checking { val kind = if pattern then "pattern selector" else "value" report.warning(MatchableWarning(tp, pattern), pos) - def checkCanThrow(tp: Type, span: Span)(using Context): Unit = + def checkCanThrow(tp: Type, span: Span)(using Context): Tree = if Feature.enabled(Feature.saferExceptions) && tp.isCheckedException then ctx.typer.implicitArgTree(defn.CanThrowClass.typeRef.appliedTo(tp), span) + else + EmptyTree /** Check that catch can generate a good CanThrow exception */ def checkCatch(pat: Tree, guard: Tree)(using Context): Unit = pat match @@ -1409,7 +1411,7 @@ trait ReChecking extends Checking { override def checkAnnotApplicable(annot: Tree, sym: Symbol)(using Context): Boolean = true override def checkMatchable(tp: Type, pos: SrcPos, pattern: Boolean)(using Context): Unit = () override def checkNoModuleClash(sym: Symbol)(using Context) = () - override def checkCanThrow(tp: Type, span: Span)(using Context): Unit = () + override def checkCanThrow(tp: Type, span: Span)(using Context): Tree = EmptyTree override def checkCatch(pat: Tree, guard: Tree)(using Context): Unit = () } diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index dfe5cbf40f13..081d3262e426 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1812,11 +1812,16 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer desugar.makeTryCase(handler1) :: Nil typedTry(untpd.Try(tree.expr, cases, tree.finalizer).withSpan(tree.span), pt) - def typedThrow(tree: untpd.Throw)(using Context): Tree = { + def typedThrow(tree: untpd.Throw)(using Context): Tree = val expr1 = typed(tree.expr, defn.ThrowableType) - checkCanThrow(expr1.tpe.widen, tree.span) - Throw(expr1).withSpan(tree.span) - } + val cap = checkCanThrow(expr1.tpe.widen, tree.span) + val res = Throw(expr1).withSpan(tree.span) + if cap.isEmpty || !ctx.settings.Ycc.value || ctx.isAfterTyper then res + else + Typed(res, + TypeTree( + AnnotatedType(res.tpe, + Annotation(defn.RequiresCapabilityAnnot, cap)))) def typedSeqLiteral(tree: untpd.SeqLiteral, pt: Type)(using Context): SeqLiteral = { val elemProto = pt.stripNull.elemType match { diff --git a/library/src-bootstrapped/scala/internal/requiresCapability.scala b/library/src-bootstrapped/scala/internal/requiresCapability.scala new file mode 100644 index 000000000000..371c44173f0b --- /dev/null +++ b/library/src-bootstrapped/scala/internal/requiresCapability.scala @@ -0,0 +1,8 @@ +package scala.annotation.internal + +import scala.annotation.StaticAnnotation + +/** An annotation to record a required capaility in the type of a throws + */ +class requiresCapability(capability: Any) extends StaticAnnotation + diff --git a/tests/neg-custom-args/captures/lazylists-exceptions.check b/tests/neg-custom-args/captures/lazylists-exceptions.check new file mode 100644 index 000000000000..cf55ff8ebd55 --- /dev/null +++ b/tests/neg-custom-args/captures/lazylists-exceptions.check @@ -0,0 +1,9 @@ +-- Error: tests/neg-custom-args/captures/lazylists-exceptions.scala:36:2 ----------------------------------------------- +36 | try // error + | ^ + | the expression's type {*} LazyList[Int] is not allowed to capture the root capability `*` +37 | tabulate(10) { i => +38 | if i > 9 then throw Ex1() +39 | i * i +40 | } +41 | catch case ex: Ex1 => LazyNil diff --git a/tests/neg-custom-args/captures/lazylists-exceptions.scala b/tests/neg-custom-args/captures/lazylists-exceptions.scala new file mode 100644 index 000000000000..6d325abcf936 --- /dev/null +++ b/tests/neg-custom-args/captures/lazylists-exceptions.scala @@ -0,0 +1,45 @@ +import language.experimental.saferExceptions + +trait LazyList[+A]: + this: {*} LazyList[A] => + + def isEmpty: Boolean + def head: A + def tail: {this} LazyList[A] + +object LazyNil extends LazyList[Nothing]: + def isEmpty: Boolean = true + def head = ??? + def tail = ??? + +final class LazyCons[+T](val x: T, val xs: () => {*} LazyList[T]) extends LazyList[T]: + this: {*} LazyList[T] => + + def isEmpty = false + def head = x + def tail: {this} LazyList[T] = xs() +end LazyCons + +extension [A](x: A) + def #:(xs1: => {*} LazyList[A]): {xs1} LazyList[A] = + LazyCons(x, () => xs1) + +def tabulate[A](n: Int)(gen: Int => A) = + def recur(i: Int): {gen} LazyList[A] = + if i == n then LazyNil + else gen(i) #: recur(i + 1) + recur(0) + +class Ex1 extends Exception + +def problem = + try // error + tabulate(10) { i => + if i > 9 then throw Ex1() + i * i + } + catch case ex: Ex1 => LazyNil + + + + diff --git a/tests/pos-custom-args/captures/i13816.scala b/tests/pos-custom-args/captures/i13816.scala index b8f9db405188..235afef35f1c 100644 --- a/tests/pos-custom-args/captures/i13816.scala +++ b/tests/pos-custom-args/captures/i13816.scala @@ -27,10 +27,10 @@ def foo5(i: Int)(using CanThrow[Ex1])(using CanThrow[Ex2]): Unit = def foo6(i: Int)(using CanThrow[Ex1 | Ex2]): Unit = if i > 0 then throw new Ex1 else throw new Ex2 -def foo7(i: Int)(using CanThrow[Ex1]): Unit throws Ex2 = +def foo7(i: Int)(using CanThrow[Ex1]): Unit throws Ex1 | Ex2 = if i > 0 then throw new Ex1 else throw new Ex2 -def foo8(i: Int)(using CanThrow[Ex2]): Unit throws Ex1 = +def foo8(i: Int)(using CanThrow[Ex2]): Unit throws Ex2 | Ex1 = if i > 0 then throw new Ex1 else throw new Ex2 def test(): Unit = diff --git a/tests/pos-custom-args/captures/lazylists-exceptions.scala b/tests/pos-custom-args/captures/lazylists-exceptions.scala index b9b303118358..96b179e564c3 100644 --- a/tests/pos-custom-args/captures/lazylists-exceptions.scala +++ b/tests/pos-custom-args/captures/lazylists-exceptions.scala @@ -52,6 +52,12 @@ extension [A](x: A) def lazyCons[A](x: A, xs1: => {*} LazyList[A]): {xs1} LazyList[A] = LazyCons(x, () => xs1) +def tabulate[A](n: Int)(gen: Int => A) = + def recur(i: Int): {gen} LazyList[A] = + if i == n then LazyNil + else gen(i) #: recur(i + 1) + recur(0) + class Ex1 extends Exception class Ex2 extends Exception @@ -72,4 +78,9 @@ def test(using cap1: CanThrow[Ex1], cap2: CanThrow[Ex2]) = def x2 = x1.concat(xs.map(g).filter(_ > 0)) def x2c: {cap1, cap2} LazyList[Int] = x2 + val x3 = tabulate(10) { i => + if i > 9 then throw Ex1() + i * i + } + val x3c: {cap1} LazyList[Int] = x3