diff --git a/compiler/src/dotty/tools/dotc/ast/tpd.scala b/compiler/src/dotty/tools/dotc/ast/tpd.scala index dd83a52ef40b..5adbef96ec89 100644 --- a/compiler/src/dotty/tools/dotc/ast/tpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/tpd.scala @@ -344,24 +344,27 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { /** An anonymous class * - * new parents { forwarders } + * new parents { termForwarders; typeAliases } * - * where `forwarders` contains forwarders for all functions in `fns`. - * @param parents a non-empty list of class types - * @param fns a non-empty of functions for which forwarders should be defined in the class. - * The class has the same owner as the first function in `fns`. - * Its position is the union of all functions in `fns`. + * @param parents a non-empty list of class types + * @param termForwarders a non-empty list of forwarding definitions specified by their name and the definition they forward to. + * @param typeMembers a possibly-empty list of type members specified by their name and their right hand side. + * + * The class has the same owner as the first function in `termForwarders`. + * Its position is the union of all symbols in `termForwarders`. */ - def AnonClass(parents: List[Type], fns: List[TermSymbol], methNames: List[TermName])(using Context): Block = { - AnonClass(fns.head.owner, parents, fns.map(_.span).reduceLeft(_ union _)) { cls => - def forwarder(fn: TermSymbol, name: TermName) = { + def AnonClass(parents: List[Type], termForwarders: List[(TermName, TermSymbol)], + typeMembers: List[(TypeName, TypeBounds)] = Nil)(using Context): Block = { + AnonClass(termForwarders.head._2.owner, parents, termForwarders.map(_._2.span).reduceLeft(_ union _)) { cls => + def forwarder(name: TermName, fn: TermSymbol) = { val fwdMeth = fn.copy(cls, name, Synthetic | Method | Final).entered.asTerm for overridden <- fwdMeth.allOverriddenSymbols do if overridden.is(Extension) then fwdMeth.setFlag(Extension) if !overridden.is(Deferred) then fwdMeth.setFlag(Override) DefDef(fwdMeth, ref(fn).appliedToArgss(_)) } - fns.lazyZip(methNames).map(forwarder) + termForwarders.map((name, sym) => forwarder(name, sym)) ++ + typeMembers.map((name, info) => TypeDef(newSymbol(cls, name, Synthetic, info).entered)) } } diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index 9d0b3d839619..3c887bfc076d 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -5534,13 +5534,16 @@ object Types extends TypeUtils { * and PolyType not allowed!) according to `possibleSamMethods`. * - can be instantiated without arguments or with just () as argument. * + * Additionally, a SAM type may contain type aliases refinements if they refine + * an existing type member. + * * The pattern `SAMType(samMethod, samParent)` matches a SAM type, where `samMethod` is the * type of the single abstract method and `samParent` is a subtype of the matched * SAM type which has been stripped of wildcards to turn it into a valid parent * type. */ object SAMType { - /** If possible, return a type which is both a subtype of `origTp` and a type + /** If possible, return a type which is both a subtype of `origTp` and a (possibly refined) type * application of `samClass` where none of the type arguments are * wildcards (thus making it a valid parent type), otherwise return * NoType. @@ -5570,27 +5573,41 @@ object Types extends TypeUtils { * we arbitrarily pick the upper-bound. */ def samParent(origTp: Type, samClass: Symbol, samMeth: Symbol)(using Context): Type = - val tp = origTp.baseType(samClass) + val tp0 = origTp.baseType(samClass) + + /** Copy type aliases refinements to `toTp` from `fromTp` */ + def withRefinements(toType: Type, fromTp: Type): Type = fromTp.dealias match + case RefinedType(fromParent, name, info: TypeAlias) if tp0.member(name).exists => + val parent1 = withRefinements(toType, fromParent) + RefinedType(toType, name, info) + case _ => toType + val tp = withRefinements(tp0, origTp) + if !(tp <:< origTp) then NoType - else tp match - case tp @ AppliedType(tycon, args) if tp.hasWildcardArg => - val accu = new TypeAccumulator[VarianceMap[Symbol]]: - def apply(vmap: VarianceMap[Symbol], t: Type): VarianceMap[Symbol] = t match - case tp: TypeRef if tp.symbol.isAllOf(ClassTypeParam) => - vmap.recordLocalVariance(tp.symbol, variance) - case _ => - foldOver(vmap, t) - val vmap = accu(VarianceMap.empty, samMeth.info) - val tparams = tycon.typeParamSymbols - val args1 = args.zipWithConserve(tparams): - case (arg @ TypeBounds(lo, hi), tparam) => - val v = vmap.computedVariance(tparam) - if v.uncheckedNN < 0 then lo - else hi - case (arg, _) => arg - tp.derivedAppliedType(tycon, args1) - case _ => - tp + else + def approxWildcardArgs(tp: Type): Type = tp match + case tp @ AppliedType(tycon, args) if tp.hasWildcardArg => + val accu = new TypeAccumulator[VarianceMap[Symbol]]: + def apply(vmap: VarianceMap[Symbol], t: Type): VarianceMap[Symbol] = t match + case tp: TypeRef if tp.symbol.isAllOf(ClassTypeParam) => + vmap.recordLocalVariance(tp.symbol, variance) + case _ => + foldOver(vmap, t) + val vmap = accu(VarianceMap.empty, samMeth.info) + val tparams = tycon.typeParamSymbols + val args1 = args.zipWithConserve(tparams): + case (arg @ TypeBounds(lo, hi), tparam) => + val v = vmap.computedVariance(tparam) + if v.uncheckedNN < 0 then lo + else hi + case (arg, _) => arg + tp.derivedAppliedType(tycon, args1) + case tp @ RefinedType(parent, name, info) => + tp.derivedRefinedType(approxWildcardArgs(parent), name, info) + case _ => + tp + approxWildcardArgs(tp) + end samParent def samClass(tp: Type)(using Context): Symbol = tp match case tp: ClassInfo => diff --git a/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala b/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala index 47c2dad1d61b..4347cca7f9d9 100644 --- a/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala +++ b/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala @@ -6,6 +6,8 @@ import core.* import Scopes.newScope import Contexts.*, Symbols.*, Types.*, Flags.*, Decorators.*, StdNames.*, Constants.* import MegaPhase.* +import Names.TypeName +import SymUtils.* import NullOpsDecorator.* import ast.untpd @@ -50,16 +52,28 @@ class ExpandSAMs extends MiniPhase: case tpe if defn.isContextFunctionType(tpe) => tree case SAMType(_, tpe) if tpe.isRef(defn.PartialFunctionClass) => - val tpe1 = checkRefinements(tpe, fn) - toPartialFunction(tree, tpe1) + toPartialFunction(tree, tpe) case SAMType(_, tpe) if ExpandSAMs.isPlatformSam(tpe.classSymbol.asClass) => - checkRefinements(tpe, fn) tree case tpe => - val tpe1 = checkRefinements(tpe.stripNull, fn) + // A SAM type is allowed to have type aliases refinements (see + // SAMType#samParent) which must be converted into type members if + // the closure is desugared into a class. + val refinements = collection.mutable.ListBuffer[(TypeName, TypeAlias)]() + def collectAndStripRefinements(tp: Type): Type = tp match + case RefinedType(parent, name, info: TypeAlias) => + val res = collectAndStripRefinements(parent) + refinements += ((name.asTypeName, info)) + res + case _ => tp + val tpe1 = collectAndStripRefinements(tpe) val Seq(samDenot) = tpe1.possibleSamMethods cpy.Block(tree)(stats, - AnonClass(tpe1 :: Nil, fn.symbol.asTerm :: Nil, samDenot.symbol.asTerm.name :: Nil)) + AnonClass(List(tpe1), + List(samDenot.symbol.asTerm.name -> fn.symbol.asTerm), + refinements.toList + ) + ) } case _ => tree @@ -170,13 +184,4 @@ class ExpandSAMs extends MiniPhase: List(isDefinedAtDef, applyOrElseDef) } } - - private def checkRefinements(tpe: Type, tree: Tree)(using Context): Type = tpe.dealias match { - case RefinedType(parent, name, _) => - if (name.isTermName && tpe.member(name).symbol.ownersIterator.isEmpty) // if member defined in the refinement - report.error(em"Lambda does not define $name", tree.srcPos) - checkRefinements(parent, tree) - case tpe => - tpe - } end ExpandSAMs diff --git a/tests/run/i18315.scala b/tests/run/i18315.scala new file mode 100644 index 000000000000..85824920efbd --- /dev/null +++ b/tests/run/i18315.scala @@ -0,0 +1,15 @@ +trait Sam1: + type T + def apply(x: T): T + +trait Sam2: + var x: Int = 1 // To force anonymous class generation + type T + def apply(x: T): T + +object Test: + def main(args: Array[String]): Unit = + val s1: Sam1 { type T = String } = x => x.trim + s1.apply("foo") + val s2: Sam2 { type T = Int } = x => x + 1 + s2.apply(1)