Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement individual erased parameters #16507

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1498,10 +1498,10 @@ object desugar {
case vd: ValDef => vd
}

def makeContextualFunction(formals: List[Tree], body: Tree, isErased: Boolean)(using Context): Function = {
val mods = if (isErased) Given | Erased else Given
def makeContextualFunction(formals: List[Tree], body: Tree, erasedParams: List[Boolean])(using Context): Function = {
val mods = Given
val params = makeImplicitParameters(formals, mods)
FunctionWithMods(params, body, Modifiers(mods))
FunctionWithMods(params, body, Modifiers(mods), erasedParams)
}

private def derivedValDef(original: Tree, named: NameTree, tpt: Tree, rhs: Tree, mods: Modifiers)(using Context) = {
Expand Down Expand Up @@ -1834,6 +1834,7 @@ object desugar {
cpy.ByNameTypeTree(parent)(annotate(tpnme.retainsByName, restpt))
case _ =>
annotate(tpnme.retains, parent)
case f: FunctionWithMods if f.hasErasedParams => makeFunctionWithValDefs(f, pt)
}
desugared.withSpan(tree.span)
}
Expand Down Expand Up @@ -1909,6 +1910,28 @@ object desugar {
TypeDef(tpnme.REFINE_CLASS, impl).withFlags(Trait)
}

/** Ensure the given function tree use only ValDefs for parameters.
* For example,
* FunctionWithMods(List(TypeTree(A), TypeTree(B)), body, mods, erasedParams)
* gets converted to
* FunctionWithMods(List(ValDef(x$1, A), ValDef(x$2, B)), body, mods, erasedParams)
*/
def makeFunctionWithValDefs(tree: Function, pt: Type)(using Context): Function = {
val Function(args, result) = tree
args match {
case (_ : ValDef) :: _ => tree // ValDef case can be easily handled
case _ if !ctx.mode.is(Mode.Type) => tree
case _ =>
val applyVParams = args.zipWithIndex.map {
case (p, n) => makeSyntheticParameter(n + 1, p)
}
tree match
case tree: FunctionWithMods =>
untpd.FunctionWithMods(applyVParams, tree.body, tree.mods, tree.erasedParams)
case _ => untpd.Function(applyVParams, result)
}
}

/** Returns list of all pattern variables, possibly with their types,
* without duplicates
*/
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -960,7 +960,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
&& tree.isTerm
&& {
val qualType = tree.qualifier.tpe
hasRefinement(qualType) && !qualType.derivesFrom(defn.PolyFunctionClass)
hasRefinement(qualType) && !defn.isRefinedFunctionType(qualType)
}
def loop(tree: Tree): Boolean = tree match
case TypeApply(fun, _) =>
Expand Down
10 changes: 5 additions & 5 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -260,12 +260,12 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
// If `isParamDependent == false`, the value of `previousParamRefs` is not used.
if isParamDependent then mutable.ListBuffer[TermRef]() else (null: ListBuffer[TermRef] | Null).uncheckedNN

def valueParam(name: TermName, origInfo: Type): TermSymbol =
def valueParam(name: TermName, origInfo: Type, isErased: Boolean): TermSymbol =
val maybeImplicit =
if tp.isContextualMethod then Given
else if tp.isImplicitMethod then Implicit
else EmptyFlags
val maybeErased = if tp.isErasedMethod then Erased else EmptyFlags
val maybeErased = if isErased then Erased else EmptyFlags

def makeSym(info: Type) = newSymbol(sym, name, TermParam | maybeImplicit | maybeErased, info, coord = sym.coord)

Expand All @@ -283,7 +283,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
assert(vparams.hasSameLengthAs(tp.paramNames) && vparams.head.isTerm)
(vparams.asInstanceOf[List[TermSymbol]], remaining1)
case nil =>
(tp.paramNames.lazyZip(tp.paramInfos).map(valueParam), Nil)
(tp.paramNames.lazyZip(tp.paramInfos).lazyZip(tp.erasedParams).map(valueParam), Nil)
val (rtp, paramss) = recur(tp.instantiate(vparams.map(_.termRef)), remaining1)
(rtp, vparams :: paramss)
case _ =>
Expand Down Expand Up @@ -1140,10 +1140,10 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {

def etaExpandCFT(using Context): Tree =
def expand(target: Tree, tp: Type)(using Context): Tree = tp match
case defn.ContextFunctionType(argTypes, resType, isErased) =>
case defn.ContextFunctionType(argTypes, resType, _) =>
val anonFun = newAnonFun(
ctx.owner,
MethodType.companion(isContextual = true, isErased = isErased)(argTypes, resType),
MethodType.companion(isContextual = true)(argTypes, resType),
coord = ctx.owner.coord)
def lambdaBody(refss: List[List[Tree]]) =
expand(target.select(nme.apply).appliedToArgss(refss), resType)(
Expand Down
10 changes: 7 additions & 3 deletions compiler/src/dotty/tools/dotc/ast/untpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,13 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
override def isType: Boolean = body.isType
}

/** A function type or closure with `implicit`, `erased`, or `given` modifiers */
class FunctionWithMods(args: List[Tree], body: Tree, val mods: Modifiers)(implicit @constructorOnly src: SourceFile)
extends Function(args, body)
/** A function type or closure with `implicit` or `given` modifiers and information on which parameters are `erased` */
class FunctionWithMods(args: List[Tree], body: Tree, val mods: Modifiers, val erasedParams: List[Boolean])(implicit @constructorOnly src: SourceFile)
extends Function(args, body) {
assert(args.length == erasedParams.length)

def hasErasedParams = erasedParams.contains(true)
}

/** A polymorphic function type */
case class PolyFunction(targs: List[Tree], body: Tree)(implicit @constructorOnly src: SourceFile) extends Tree {
Expand Down
1 change: 0 additions & 1 deletion compiler/src/dotty/tools/dotc/cc/CaptureOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ extension (tp: Type)
defn.FunctionType(
fname.functionArity,
isContextual = fname.isContextFunction,
isErased = fname.isErasedFunction,
isImpure = true).appliedTo(args)
case _ =>
tp
Expand Down
20 changes: 10 additions & 10 deletions compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -336,8 +336,8 @@ class CheckCaptures extends Recheck, SymTransformer:
mapArgUsing(_.forceBoxStatus(false))
else if meth == defn.Caps_unsafeBoxFunArg then
mapArgUsing {
case defn.FunctionOf(paramtpe :: Nil, restpe, isContectual, isErased) =>
defn.FunctionOf(paramtpe.forceBoxStatus(true) :: Nil, restpe, isContectual, isErased)
case defn.FunctionOf(paramtpe :: Nil, restpe, isContectual) =>
defn.FunctionOf(paramtpe.forceBoxStatus(true) :: Nil, restpe, isContectual)
}
else
super.recheckApply(tree, pt) match
Expand Down Expand Up @@ -430,7 +430,7 @@ class CheckCaptures extends Recheck, SymTransformer:
block match
case closureDef(mdef) =>
pt.dealias match
case defn.FunctionOf(ptformals, _, _, _)
case defn.FunctionOf(ptformals, _, _)
if ptformals.nonEmpty && ptformals.forall(_.captureSet.isAlwaysEmpty) =>
// Redo setup of the anonymous function so that formal parameters don't
// get capture sets. This is important to avoid false widenings to `*`
Expand Down Expand Up @@ -598,18 +598,18 @@ class CheckCaptures extends Recheck, SymTransformer:
//println(i"check conforms $actual1 <<< $expected1")
super.checkConformsExpr(actual1, expected1, tree)

private def toDepFun(args: List[Type], resultType: Type, isContextual: Boolean, isErased: Boolean)(using Context): Type =
MethodType.companion(isContextual = isContextual, isErased = isErased)(args, resultType)
private def toDepFun(args: List[Type], resultType: Type, isContextual: Boolean)(using Context): Type =
MethodType.companion(isContextual = isContextual)(args, resultType)
.toFunctionType(isJava = false, alwaysDependent = true)

/** Turn `expected` into a dependent function when `actual` is dependent. */
private def alignDependentFunction(expected: Type, actual: Type)(using Context): Type =
def recur(expected: Type): Type = expected.dealias match
case expected @ CapturingType(eparent, refs) =>
CapturingType(recur(eparent), refs, boxed = expected.isBoxed)
case expected @ defn.FunctionOf(args, resultType, isContextual, isErased)
case expected @ defn.FunctionOf(args, resultType, isContextual)
if defn.isNonRefinedFunction(expected) && defn.isFunctionType(actual) && !defn.isNonRefinedFunction(actual) =>
val expected1 = toDepFun(args, resultType, isContextual, isErased)
val expected1 = toDepFun(args, resultType, isContextual)
expected1
case _ =>
expected
Expand Down Expand Up @@ -675,7 +675,7 @@ class CheckCaptures extends Recheck, SymTransformer:

try
val (eargs, eres) = expected.dealias.stripCapturing match
case defn.FunctionOf(eargs, eres, _, _) => (eargs, eres)
case defn.FunctionOf(eargs, eres, _) => (eargs, eres)
case expected: MethodType => (expected.paramInfos, expected.resType)
case expected @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(expected) => (rinfo.paramInfos, rinfo.resType)
case _ => (aargs.map(_ => WildcardType), WildcardType)
Expand Down Expand Up @@ -739,7 +739,7 @@ class CheckCaptures extends Recheck, SymTransformer:
case actual @ AppliedType(tycon, args) if defn.isNonRefinedFunction(actual) =>
adaptFun(actual, args.init, args.last, expected, covariant, insertBox,
(aargs1, ares1) => actual.derivedAppliedType(tycon, aargs1 :+ ares1))
case actual @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(actual) =>
case actual @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionOrPolyType(actual) =>
// TODO Find a way to combine handling of generic and dependent function types (here and elsewhere)
adaptFun(actual, rinfo.paramInfos, rinfo.resType, expected, covariant, insertBox,
(aargs1, ares1) =>
Expand Down Expand Up @@ -962,7 +962,7 @@ class CheckCaptures extends Recheck, SymTransformer:
case CapturingType(parent, refs) =>
healCaptureSet(refs)
traverse(parent)
case tp @ RefinedType(parent, rname, rinfo: MethodType) if defn.isFunctionType(tp) =>
case tp @ RefinedType(parent, rname, rinfo: MethodType) if defn.isFunctionOrPolyType(tp) =>
traverse(rinfo)
case tp: TermLambda =>
val saved = allowed
Expand Down
16 changes: 11 additions & 5 deletions compiler/src/dotty/tools/dotc/cc/Setup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import transform.Recheck.*
import CaptureSet.IdentityCaptRefMap
import Synthetics.isExcluded
import util.Property
import dotty.tools.dotc.core.Annotations.Annotation

/** A tree traverser that prepares a compilation unit to be capture checked.
* It does the following:
Expand All @@ -38,7 +39,6 @@ extends tpd.TreeTraverser:
private def depFun(tycon: Type, argTypes: List[Type], resType: Type)(using Context): Type =
MethodType.companion(
isContextual = defn.isContextFunctionClass(tycon.classSymbol),
isErased = defn.isErasedFunctionClass(tycon.classSymbol)
)(argTypes, resType)
.toFunctionType(isJava = false, alwaysDependent = true)

Expand All @@ -54,7 +54,7 @@ extends tpd.TreeTraverser:
val boxedRes = recur(res)
if boxedRes eq res then tp
else tp1.derivedAppliedType(tycon, args.init :+ boxedRes)
case tp1 @ RefinedType(_, _, rinfo) if defn.isFunctionType(tp1) =>
case tp1 @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionOrPolyType(tp1) =>
val boxedRinfo = recur(rinfo)
if boxedRinfo eq rinfo then tp
else boxedRinfo.toFunctionType(isJava = false, alwaysDependent = true)
Expand Down Expand Up @@ -231,7 +231,7 @@ extends tpd.TreeTraverser:
tp.derivedAppliedType(tycon1, args1 :+ res1)
else
tp.derivedAppliedType(tycon1, args.mapConserve(arg => this(arg)))
case tp @ RefinedType(core, rname, rinfo) if defn.isFunctionType(tp) =>
case tp @ RefinedType(core, rname, rinfo: MethodType) if defn.isFunctionOrPolyType(tp) =>
val rinfo1 = apply(rinfo)
if rinfo1 ne rinfo then rinfo1.toFunctionType(isJava = false, alwaysDependent = true)
else tp
Expand Down Expand Up @@ -260,7 +260,13 @@ extends tpd.TreeTraverser:
private def expandThrowsAlias(tp: Type)(using Context) = tp match
case AppliedType(tycon, res :: exc :: Nil) if tycon.typeSymbol == defn.throwsAlias =>
// hard-coded expansion since $throws aliases in stdlib are defined with `?=>` rather than `?->`
defn.FunctionOf(defn.CanThrowClass.typeRef.appliedTo(exc) :: Nil, res, isContextual = true, isErased = true)
defn.FunctionOf(
AnnotatedType(
defn.CanThrowClass.typeRef.appliedTo(exc),
Annotation(defn.ErasedParamAnnot, defn.CanThrowClass.span)) :: Nil,
res,
isContextual = true
)
case _ => tp

private def expandThrowsAliases(using Context) = new TypeMap:
Expand Down Expand Up @@ -323,7 +329,7 @@ extends tpd.TreeTraverser:
args.last, CaptureSet.empty, currentCs ++ outerCs)
tp.derivedAppliedType(tycon1, args1 :+ resType1)
tp1.capturing(outerCs)
case tp @ RefinedType(parent, nme.apply, rinfo: MethodType) if defn.isFunctionType(tp) =>
case tp @ RefinedType(parent, nme.apply, rinfo: MethodType) if defn.isFunctionOrPolyType(tp) =>
propagateDepFunctionResult(mapOver(tp), currentCs ++ outerCs)
.capturing(outerCs)
case _ =>
Expand Down
Loading