Skip to content

Commit

Permalink
Reimplement support for type aliases in SAM types
Browse files Browse the repository at this point in the history
This was dropped in scala#18201 which restricted SAM types to valid parent types,
but it turns out that there is code in the wild that relies on refinements
being allowed here.

To support this properly, we had to enhance ExpandSAMs to move refinements into
type members to pass Ycheck (previous Scala 3 releases would accept the code in
tests/run/i18315.scala but fail Ycheck).

Fixes scala#18315.
  • Loading branch information
smarter authored and WojciechMazur committed Sep 5, 2024
1 parent 419cfe6 commit e6cccca
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 45 deletions.
23 changes: 13 additions & 10 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -344,24 +344,27 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {

/** An anonymous class
*
* new parents { forwarders }
* new parents { termForwarders; typeAliases }
*
* where `forwarders` contains forwarders for all functions in `fns`.
* @param parents a non-empty list of class types
* @param fns a non-empty of functions for which forwarders should be defined in the class.
* The class has the same owner as the first function in `fns`.
* Its position is the union of all functions in `fns`.
* @param parents a non-empty list of class types
* @param termForwarders a non-empty list of forwarding definitions specified by their name and the definition they forward to.
* @param typeMembers a possibly-empty list of type members specified by their name and their right hand side.
*
* The class has the same owner as the first function in `termForwarders`.
* Its position is the union of all symbols in `termForwarders`.
*/
def AnonClass(parents: List[Type], fns: List[TermSymbol], methNames: List[TermName])(using Context): Block = {
AnonClass(fns.head.owner, parents, fns.map(_.span).reduceLeft(_ union _)) { cls =>
def forwarder(fn: TermSymbol, name: TermName) = {
def AnonClass(parents: List[Type], termForwarders: List[(TermName, TermSymbol)],
typeMembers: List[(TypeName, TypeBounds)] = Nil)(using Context): Block = {
AnonClass(termForwarders.head._2.owner, parents, termForwarders.map(_._2.span).reduceLeft(_ union _)) { cls =>
def forwarder(name: TermName, fn: TermSymbol) = {
val fwdMeth = fn.copy(cls, name, Synthetic | Method | Final).entered.asTerm
for overridden <- fwdMeth.allOverriddenSymbols do
if overridden.is(Extension) then fwdMeth.setFlag(Extension)
if !overridden.is(Deferred) then fwdMeth.setFlag(Override)
DefDef(fwdMeth, ref(fn).appliedToArgss(_))
}
fns.lazyZip(methNames).map(forwarder)
termForwarders.map((name, sym) => forwarder(name, sym)) ++
typeMembers.map((name, info) => TypeDef(newSymbol(cls, name, Synthetic, info).entered))
}
}

Expand Down
59 changes: 38 additions & 21 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5534,13 +5534,16 @@ object Types extends TypeUtils {
* and PolyType not allowed!) according to `possibleSamMethods`.
* - can be instantiated without arguments or with just () as argument.
*
* Additionally, a SAM type may contain type aliases refinements if they refine
* an existing type member.
*
* The pattern `SAMType(samMethod, samParent)` matches a SAM type, where `samMethod` is the
* type of the single abstract method and `samParent` is a subtype of the matched
* SAM type which has been stripped of wildcards to turn it into a valid parent
* type.
*/
object SAMType {
/** If possible, return a type which is both a subtype of `origTp` and a type
/** If possible, return a type which is both a subtype of `origTp` and a (possibly refined) type
* application of `samClass` where none of the type arguments are
* wildcards (thus making it a valid parent type), otherwise return
* NoType.
Expand Down Expand Up @@ -5570,27 +5573,41 @@ object Types extends TypeUtils {
* we arbitrarily pick the upper-bound.
*/
def samParent(origTp: Type, samClass: Symbol, samMeth: Symbol)(using Context): Type =
val tp = origTp.baseType(samClass)
val tp0 = origTp.baseType(samClass)

/** Copy type aliases refinements to `toTp` from `fromTp` */
def withRefinements(toType: Type, fromTp: Type): Type = fromTp.dealias match
case RefinedType(fromParent, name, info: TypeAlias) if tp0.member(name).exists =>
val parent1 = withRefinements(toType, fromParent)
RefinedType(toType, name, info)
case _ => toType
val tp = withRefinements(tp0, origTp)

if !(tp <:< origTp) then NoType
else tp match
case tp @ AppliedType(tycon, args) if tp.hasWildcardArg =>
val accu = new TypeAccumulator[VarianceMap[Symbol]]:
def apply(vmap: VarianceMap[Symbol], t: Type): VarianceMap[Symbol] = t match
case tp: TypeRef if tp.symbol.isAllOf(ClassTypeParam) =>
vmap.recordLocalVariance(tp.symbol, variance)
case _ =>
foldOver(vmap, t)
val vmap = accu(VarianceMap.empty, samMeth.info)
val tparams = tycon.typeParamSymbols
val args1 = args.zipWithConserve(tparams):
case (arg @ TypeBounds(lo, hi), tparam) =>
val v = vmap.computedVariance(tparam)
if v.uncheckedNN < 0 then lo
else hi
case (arg, _) => arg
tp.derivedAppliedType(tycon, args1)
case _ =>
tp
else
def approxWildcardArgs(tp: Type): Type = tp match
case tp @ AppliedType(tycon, args) if tp.hasWildcardArg =>
val accu = new TypeAccumulator[VarianceMap[Symbol]]:
def apply(vmap: VarianceMap[Symbol], t: Type): VarianceMap[Symbol] = t match
case tp: TypeRef if tp.symbol.isAllOf(ClassTypeParam) =>
vmap.recordLocalVariance(tp.symbol, variance)
case _ =>
foldOver(vmap, t)
val vmap = accu(VarianceMap.empty, samMeth.info)
val tparams = tycon.typeParamSymbols
val args1 = args.zipWithConserve(tparams):
case (arg @ TypeBounds(lo, hi), tparam) =>
val v = vmap.computedVariance(tparam)
if v.uncheckedNN < 0 then lo
else hi
case (arg, _) => arg
tp.derivedAppliedType(tycon, args1)
case tp @ RefinedType(parent, name, info) =>
tp.derivedRefinedType(approxWildcardArgs(parent), name, info)
case _ =>
tp
approxWildcardArgs(tp)
end samParent

def samClass(tp: Type)(using Context): Symbol = tp match
case tp: ClassInfo =>
Expand Down
33 changes: 19 additions & 14 deletions compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import core.*
import Scopes.newScope
import Contexts.*, Symbols.*, Types.*, Flags.*, Decorators.*, StdNames.*, Constants.*
import MegaPhase.*
import Names.TypeName
import SymUtils.*
import NullOpsDecorator.*
import ast.untpd

Expand Down Expand Up @@ -50,16 +52,28 @@ class ExpandSAMs extends MiniPhase:
case tpe if defn.isContextFunctionType(tpe) =>
tree
case SAMType(_, tpe) if tpe.isRef(defn.PartialFunctionClass) =>
val tpe1 = checkRefinements(tpe, fn)
toPartialFunction(tree, tpe1)
toPartialFunction(tree, tpe)
case SAMType(_, tpe) if ExpandSAMs.isPlatformSam(tpe.classSymbol.asClass) =>
checkRefinements(tpe, fn)
tree
case tpe =>
val tpe1 = checkRefinements(tpe.stripNull, fn)
// A SAM type is allowed to have type aliases refinements (see
// SAMType#samParent) which must be converted into type members if
// the closure is desugared into a class.
val refinements = collection.mutable.ListBuffer[(TypeName, TypeAlias)]()
def collectAndStripRefinements(tp: Type): Type = tp match
case RefinedType(parent, name, info: TypeAlias) =>
val res = collectAndStripRefinements(parent)
refinements += ((name.asTypeName, info))
res
case _ => tp
val tpe1 = collectAndStripRefinements(tpe)
val Seq(samDenot) = tpe1.possibleSamMethods
cpy.Block(tree)(stats,
AnonClass(tpe1 :: Nil, fn.symbol.asTerm :: Nil, samDenot.symbol.asTerm.name :: Nil))
AnonClass(List(tpe1),
List(samDenot.symbol.asTerm.name -> fn.symbol.asTerm),
refinements.toList
)
)
}
case _ =>
tree
Expand Down Expand Up @@ -170,13 +184,4 @@ class ExpandSAMs extends MiniPhase:
List(isDefinedAtDef, applyOrElseDef)
}
}

private def checkRefinements(tpe: Type, tree: Tree)(using Context): Type = tpe.dealias match {
case RefinedType(parent, name, _) =>
if (name.isTermName && tpe.member(name).symbol.ownersIterator.isEmpty) // if member defined in the refinement
report.error(em"Lambda does not define $name", tree.srcPos)
checkRefinements(parent, tree)
case tpe =>
tpe
}
end ExpandSAMs
15 changes: 15 additions & 0 deletions tests/run/i18315.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
trait Sam1:
type T
def apply(x: T): T

trait Sam2:
var x: Int = 1 // To force anonymous class generation
type T
def apply(x: T): T

object Test:
def main(args: Array[String]): Unit =
val s1: Sam1 { type T = String } = x => x.trim
s1.apply("foo")
val s2: Sam2 { type T = Int } = x => x + 1
s2.apply(1)

0 comments on commit e6cccca

Please sign in to comment.