Skip to content

Commit

Permalink
Merge pull request #15570 from dwijnand/gadt/poly
Browse files Browse the repository at this point in the history
  • Loading branch information
dwijnand authored Jul 7, 2022
2 parents 9d07d52 + 34c2918 commit 1724d84
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 9 deletions.
28 changes: 21 additions & 7 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1504,7 +1504,7 @@ object desugar {
.withSpan(original.span.withPoint(named.span.start))

/** Main desugaring method */
def apply(tree: Tree)(using Context): Tree = {
def apply(tree: Tree, pt: Type = NoType)(using Context): Tree = {

/** Create tree for for-comprehension `<for (enums) do body>` or
* `<for (enums) yield body>` where mapName and flatMapName are chosen
Expand Down Expand Up @@ -1698,11 +1698,11 @@ object desugar {
}
}

def makePolyFunction(targs: List[Tree], body: Tree): Tree = body match {
def makePolyFunction(targs: List[Tree], body: Tree, pt: Type): Tree = body match {
case Parens(body1) =>
makePolyFunction(targs, body1)
makePolyFunction(targs, body1, pt)
case Block(Nil, body1) =>
makePolyFunction(targs, body1)
makePolyFunction(targs, body1, pt)
case Function(vargs, res) =>
assert(targs.nonEmpty)
// TODO: Figure out if we need a `PolyFunctionWithMods` instead.
Expand All @@ -1726,12 +1726,26 @@ object desugar {
}
else {
// Desugar [T_1, ..., T_M] -> (x_1: P_1, ..., x_N: P_N) => body
// Into new scala.PolyFunction { def apply[T_1, ..., T_M](x_1: P_1, ..., x_N: P_N) = body }
// with pt [S_1, ..., S_M] -> (O_1, ..., O_N) => R
// Into new scala.PolyFunction { def apply[T_1, ..., T_M](x_1: P_1, ..., x_N: P_N): R2 = body }
// where R2 is R, with all references to S_1..S_M replaced with T1..T_M.

def typeTree(tp: Type) = tp match
case RefinedType(parent, nme.apply, PolyType(_, mt)) if parent.typeSymbol eq defn.PolyFunctionClass =>
var bail = false
def mapper(tp: Type, topLevel: Boolean = false): Tree = tp match
case tp: TypeRef => ref(tp)
case tp: TypeParamRef => Ident(applyTParams(tp.paramNum).name)
case AppliedType(tycon, args) => AppliedTypeTree(mapper(tycon), args.map(mapper(_)))
case _ => if topLevel then TypeTree() else { bail = true; genericEmptyTree }
val mapped = mapper(mt.resultType, topLevel = true)
if bail then TypeTree() else mapped
case _ => TypeTree()

val applyVParams = vargs.asInstanceOf[List[ValDef]]
.map(varg => varg.withAddedFlags(mods.flags | Param))
New(Template(emptyConstructor, List(polyFunctionTpt), Nil, EmptyValDef,
List(DefDef(nme.apply, applyTParams :: applyVParams :: Nil, TypeTree(), res))
List(DefDef(nme.apply, applyTParams :: applyVParams :: Nil, typeTree(pt), res))
))
}
case _ =>
Expand All @@ -1753,7 +1767,7 @@ object desugar {

val desugared = tree match {
case PolyFunction(targs, body) =>
makePolyFunction(targs, body) orElse tree
makePolyFunction(targs, body, pt) orElse tree
case SymbolLit(str) =>
Apply(
ref(defn.ScalaSymbolClass.companionModule.termRef),
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2872,7 +2872,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer

typedTypeOrClassDef
case tree: untpd.Labeled => typedLabeled(tree)
case _ => typedUnadapted(desugar(tree), pt, locked)
case _ => typedUnadapted(desugar(tree, pt), pt, locked)
}
}

Expand Down Expand Up @@ -2925,7 +2925,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
case tree: untpd.Splice => typedSplice(tree, pt)
case tree: untpd.MacroTree => report.error("Unexpected macro", tree.srcPos); tpd.nullLiteral // ill-formed code may reach here
case tree: untpd.Hole => typedHole(tree, pt)
case _ => typedUnadapted(desugar(tree), pt, locked)
case _ => typedUnadapted(desugar(tree, pt), pt, locked)
}

try
Expand Down
8 changes: 8 additions & 0 deletions tests/pos/i15554.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
enum PingMessage[Response]:
case Ping(from: String) extends PingMessage[String]

val pongBehavior: [O] => (Unit, PingMessage[O]) => (Unit, O) =
[P] =>
(state: Unit, msg: PingMessage[P]) =>
msg match
case PingMessage.Ping(from) => ((), s"Pong from $from")

0 comments on commit 1724d84

Please sign in to comment.