Skip to content

Commit

Permalink
Implement individual erased parameters (#16507)
Browse files Browse the repository at this point in the history
### Syntax

Erased parameters in a method / lambda comes with an `erased` modifier
before its name:
```scala
def erasedSecondParam(x: Int, erased y: Int): Int = x
type EraseSecondParam[T, U] = (T, erased U) => T

val esp: EraseSecondParam[Int, Int] = (x, erased y) => erasedSecondParam(x, y)
```
This is a breaking change, as previously erased methods / functions with
multiple parameters now only have its first parameter erased.

### Semantics

`[Impure][Contextual]ErasedFunctionN` traits are no longer available.
Instead, erased function values are denoted by refining the
`scala.runtime.ErasedFunction` trait:
```scala
type Int_EInt = (Int, erased Int) => Int
// is equivalent to
type Int_EInt2 = scala.runtime.ErasedFunction {
  def apply(x$0: Int, erased x$1: Int): Int
}
```

They are subsequently compiled (during Erasure) into
`[Contextual]FunctionM` where `M` is the number of non-erased
parameters.

### Erased Classes 

Any parameter that is an instance of an erased class is automatically
erased. This is different from before, where the parameters are erased
only if all parameters are instances of erased classes.
  • Loading branch information
natsukagami authored Mar 3, 2023
2 parents 6ea3ea6 + 0f7c3ab commit e422066
Show file tree
Hide file tree
Showing 71 changed files with 879 additions and 403 deletions.
29 changes: 26 additions & 3 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1498,10 +1498,10 @@ object desugar {
case vd: ValDef => vd
}

def makeContextualFunction(formals: List[Tree], body: Tree, isErased: Boolean)(using Context): Function = {
val mods = if (isErased) Given | Erased else Given
def makeContextualFunction(formals: List[Tree], body: Tree, erasedParams: List[Boolean])(using Context): Function = {
val mods = Given
val params = makeImplicitParameters(formals, mods)
FunctionWithMods(params, body, Modifiers(mods))
FunctionWithMods(params, body, Modifiers(mods), erasedParams)
}

private def derivedValDef(original: Tree, named: NameTree, tpt: Tree, rhs: Tree, mods: Modifiers)(using Context) = {
Expand Down Expand Up @@ -1834,6 +1834,7 @@ object desugar {
cpy.ByNameTypeTree(parent)(annotate(tpnme.retainsByName, restpt))
case _ =>
annotate(tpnme.retains, parent)
case f: FunctionWithMods if f.hasErasedParams => makeFunctionWithValDefs(f, pt)
}
desugared.withSpan(tree.span)
}
Expand Down Expand Up @@ -1909,6 +1910,28 @@ object desugar {
TypeDef(tpnme.REFINE_CLASS, impl).withFlags(Trait)
}

/** Ensure the given function tree use only ValDefs for parameters.
* For example,
* FunctionWithMods(List(TypeTree(A), TypeTree(B)), body, mods, erasedParams)
* gets converted to
* FunctionWithMods(List(ValDef(x$1, A), ValDef(x$2, B)), body, mods, erasedParams)
*/
def makeFunctionWithValDefs(tree: Function, pt: Type)(using Context): Function = {
val Function(args, result) = tree
args match {
case (_ : ValDef) :: _ => tree // ValDef case can be easily handled
case _ if !ctx.mode.is(Mode.Type) => tree
case _ =>
val applyVParams = args.zipWithIndex.map {
case (p, n) => makeSyntheticParameter(n + 1, p)
}
tree match
case tree: FunctionWithMods =>
untpd.FunctionWithMods(applyVParams, tree.body, tree.mods, tree.erasedParams)
case _ => untpd.Function(applyVParams, result)
}
}

/** Returns list of all pattern variables, possibly with their types,
* without duplicates
*/
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -960,7 +960,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
&& tree.isTerm
&& {
val qualType = tree.qualifier.tpe
hasRefinement(qualType) && !qualType.derivesFrom(defn.PolyFunctionClass)
hasRefinement(qualType) && !defn.isRefinedFunctionType(qualType)
}
def loop(tree: Tree): Boolean = tree match
case TypeApply(fun, _) =>
Expand Down
10 changes: 5 additions & 5 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -260,12 +260,12 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
// If `isParamDependent == false`, the value of `previousParamRefs` is not used.
if isParamDependent then mutable.ListBuffer[TermRef]() else (null: ListBuffer[TermRef] | Null).uncheckedNN

def valueParam(name: TermName, origInfo: Type): TermSymbol =
def valueParam(name: TermName, origInfo: Type, isErased: Boolean): TermSymbol =
val maybeImplicit =
if tp.isContextualMethod then Given
else if tp.isImplicitMethod then Implicit
else EmptyFlags
val maybeErased = if tp.isErasedMethod then Erased else EmptyFlags
val maybeErased = if isErased then Erased else EmptyFlags

def makeSym(info: Type) = newSymbol(sym, name, TermParam | maybeImplicit | maybeErased, info, coord = sym.coord)

Expand All @@ -283,7 +283,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
assert(vparams.hasSameLengthAs(tp.paramNames) && vparams.head.isTerm)
(vparams.asInstanceOf[List[TermSymbol]], remaining1)
case nil =>
(tp.paramNames.lazyZip(tp.paramInfos).map(valueParam), Nil)
(tp.paramNames.lazyZip(tp.paramInfos).lazyZip(tp.erasedParams).map(valueParam), Nil)
val (rtp, paramss) = recur(tp.instantiate(vparams.map(_.termRef)), remaining1)
(rtp, vparams :: paramss)
case _ =>
Expand Down Expand Up @@ -1140,10 +1140,10 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {

def etaExpandCFT(using Context): Tree =
def expand(target: Tree, tp: Type)(using Context): Tree = tp match
case defn.ContextFunctionType(argTypes, resType, isErased) =>
case defn.ContextFunctionType(argTypes, resType, _) =>
val anonFun = newAnonFun(
ctx.owner,
MethodType.companion(isContextual = true, isErased = isErased)(argTypes, resType),
MethodType.companion(isContextual = true)(argTypes, resType),
coord = ctx.owner.coord)
def lambdaBody(refss: List[List[Tree]]) =
expand(target.select(nme.apply).appliedToArgss(refss), resType)(
Expand Down
10 changes: 7 additions & 3 deletions compiler/src/dotty/tools/dotc/ast/untpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,13 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
override def isType: Boolean = body.isType
}

/** A function type or closure with `implicit`, `erased`, or `given` modifiers */
class FunctionWithMods(args: List[Tree], body: Tree, val mods: Modifiers)(implicit @constructorOnly src: SourceFile)
extends Function(args, body)
/** A function type or closure with `implicit` or `given` modifiers and information on which parameters are `erased` */
class FunctionWithMods(args: List[Tree], body: Tree, val mods: Modifiers, val erasedParams: List[Boolean])(implicit @constructorOnly src: SourceFile)
extends Function(args, body) {
assert(args.length == erasedParams.length)

def hasErasedParams = erasedParams.contains(true)
}

/** A polymorphic function type */
case class PolyFunction(targs: List[Tree], body: Tree)(implicit @constructorOnly src: SourceFile) extends Tree {
Expand Down
1 change: 0 additions & 1 deletion compiler/src/dotty/tools/dotc/cc/CaptureOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ extension (tp: Type)
defn.FunctionType(
fname.functionArity,
isContextual = fname.isContextFunction,
isErased = fname.isErasedFunction,
isImpure = true).appliedTo(args)
case _ =>
tp
Expand Down
20 changes: 10 additions & 10 deletions compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -336,8 +336,8 @@ class CheckCaptures extends Recheck, SymTransformer:
mapArgUsing(_.forceBoxStatus(false))
else if meth == defn.Caps_unsafeBoxFunArg then
mapArgUsing {
case defn.FunctionOf(paramtpe :: Nil, restpe, isContectual, isErased) =>
defn.FunctionOf(paramtpe.forceBoxStatus(true) :: Nil, restpe, isContectual, isErased)
case defn.FunctionOf(paramtpe :: Nil, restpe, isContectual) =>
defn.FunctionOf(paramtpe.forceBoxStatus(true) :: Nil, restpe, isContectual)
}
else
super.recheckApply(tree, pt) match
Expand Down Expand Up @@ -430,7 +430,7 @@ class CheckCaptures extends Recheck, SymTransformer:
block match
case closureDef(mdef) =>
pt.dealias match
case defn.FunctionOf(ptformals, _, _, _)
case defn.FunctionOf(ptformals, _, _)
if ptformals.nonEmpty && ptformals.forall(_.captureSet.isAlwaysEmpty) =>
// Redo setup of the anonymous function so that formal parameters don't
// get capture sets. This is important to avoid false widenings to `*`
Expand Down Expand Up @@ -598,18 +598,18 @@ class CheckCaptures extends Recheck, SymTransformer:
//println(i"check conforms $actual1 <<< $expected1")
super.checkConformsExpr(actual1, expected1, tree)

private def toDepFun(args: List[Type], resultType: Type, isContextual: Boolean, isErased: Boolean)(using Context): Type =
MethodType.companion(isContextual = isContextual, isErased = isErased)(args, resultType)
private def toDepFun(args: List[Type], resultType: Type, isContextual: Boolean)(using Context): Type =
MethodType.companion(isContextual = isContextual)(args, resultType)
.toFunctionType(isJava = false, alwaysDependent = true)

/** Turn `expected` into a dependent function when `actual` is dependent. */
private def alignDependentFunction(expected: Type, actual: Type)(using Context): Type =
def recur(expected: Type): Type = expected.dealias match
case expected @ CapturingType(eparent, refs) =>
CapturingType(recur(eparent), refs, boxed = expected.isBoxed)
case expected @ defn.FunctionOf(args, resultType, isContextual, isErased)
case expected @ defn.FunctionOf(args, resultType, isContextual)
if defn.isNonRefinedFunction(expected) && defn.isFunctionType(actual) && !defn.isNonRefinedFunction(actual) =>
val expected1 = toDepFun(args, resultType, isContextual, isErased)
val expected1 = toDepFun(args, resultType, isContextual)
expected1
case _ =>
expected
Expand Down Expand Up @@ -675,7 +675,7 @@ class CheckCaptures extends Recheck, SymTransformer:

try
val (eargs, eres) = expected.dealias.stripCapturing match
case defn.FunctionOf(eargs, eres, _, _) => (eargs, eres)
case defn.FunctionOf(eargs, eres, _) => (eargs, eres)
case expected: MethodType => (expected.paramInfos, expected.resType)
case expected @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(expected) => (rinfo.paramInfos, rinfo.resType)
case _ => (aargs.map(_ => WildcardType), WildcardType)
Expand Down Expand Up @@ -739,7 +739,7 @@ class CheckCaptures extends Recheck, SymTransformer:
case actual @ AppliedType(tycon, args) if defn.isNonRefinedFunction(actual) =>
adaptFun(actual, args.init, args.last, expected, covariant, insertBox,
(aargs1, ares1) => actual.derivedAppliedType(tycon, aargs1 :+ ares1))
case actual @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(actual) =>
case actual @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionOrPolyType(actual) =>
// TODO Find a way to combine handling of generic and dependent function types (here and elsewhere)
adaptFun(actual, rinfo.paramInfos, rinfo.resType, expected, covariant, insertBox,
(aargs1, ares1) =>
Expand Down Expand Up @@ -962,7 +962,7 @@ class CheckCaptures extends Recheck, SymTransformer:
case CapturingType(parent, refs) =>
healCaptureSet(refs)
traverse(parent)
case tp @ RefinedType(parent, rname, rinfo: MethodType) if defn.isFunctionType(tp) =>
case tp @ RefinedType(parent, rname, rinfo: MethodType) if defn.isFunctionOrPolyType(tp) =>
traverse(rinfo)
case tp: TermLambda =>
val saved = allowed
Expand Down
16 changes: 11 additions & 5 deletions compiler/src/dotty/tools/dotc/cc/Setup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import transform.Recheck.*
import CaptureSet.IdentityCaptRefMap
import Synthetics.isExcluded
import util.Property
import dotty.tools.dotc.core.Annotations.Annotation

/** A tree traverser that prepares a compilation unit to be capture checked.
* It does the following:
Expand All @@ -38,7 +39,6 @@ extends tpd.TreeTraverser:
private def depFun(tycon: Type, argTypes: List[Type], resType: Type)(using Context): Type =
MethodType.companion(
isContextual = defn.isContextFunctionClass(tycon.classSymbol),
isErased = defn.isErasedFunctionClass(tycon.classSymbol)
)(argTypes, resType)
.toFunctionType(isJava = false, alwaysDependent = true)

Expand All @@ -54,7 +54,7 @@ extends tpd.TreeTraverser:
val boxedRes = recur(res)
if boxedRes eq res then tp
else tp1.derivedAppliedType(tycon, args.init :+ boxedRes)
case tp1 @ RefinedType(_, _, rinfo) if defn.isFunctionType(tp1) =>
case tp1 @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionOrPolyType(tp1) =>
val boxedRinfo = recur(rinfo)
if boxedRinfo eq rinfo then tp
else boxedRinfo.toFunctionType(isJava = false, alwaysDependent = true)
Expand Down Expand Up @@ -231,7 +231,7 @@ extends tpd.TreeTraverser:
tp.derivedAppliedType(tycon1, args1 :+ res1)
else
tp.derivedAppliedType(tycon1, args.mapConserve(arg => this(arg)))
case tp @ RefinedType(core, rname, rinfo) if defn.isFunctionType(tp) =>
case tp @ RefinedType(core, rname, rinfo: MethodType) if defn.isFunctionOrPolyType(tp) =>
val rinfo1 = apply(rinfo)
if rinfo1 ne rinfo then rinfo1.toFunctionType(isJava = false, alwaysDependent = true)
else tp
Expand Down Expand Up @@ -260,7 +260,13 @@ extends tpd.TreeTraverser:
private def expandThrowsAlias(tp: Type)(using Context) = tp match
case AppliedType(tycon, res :: exc :: Nil) if tycon.typeSymbol == defn.throwsAlias =>
// hard-coded expansion since $throws aliases in stdlib are defined with `?=>` rather than `?->`
defn.FunctionOf(defn.CanThrowClass.typeRef.appliedTo(exc) :: Nil, res, isContextual = true, isErased = true)
defn.FunctionOf(
AnnotatedType(
defn.CanThrowClass.typeRef.appliedTo(exc),
Annotation(defn.ErasedParamAnnot, defn.CanThrowClass.span)) :: Nil,
res,
isContextual = true
)
case _ => tp

private def expandThrowsAliases(using Context) = new TypeMap:
Expand Down Expand Up @@ -323,7 +329,7 @@ extends tpd.TreeTraverser:
args.last, CaptureSet.empty, currentCs ++ outerCs)
tp.derivedAppliedType(tycon1, args1 :+ resType1)
tp1.capturing(outerCs)
case tp @ RefinedType(parent, nme.apply, rinfo: MethodType) if defn.isFunctionType(tp) =>
case tp @ RefinedType(parent, nme.apply, rinfo: MethodType) if defn.isFunctionOrPolyType(tp) =>
propagateDepFunctionResult(mapOver(tp), currentCs ++ outerCs)
.capturing(outerCs)
case _ =>
Expand Down
Loading

0 comments on commit e422066

Please sign in to comment.