From cf034f03514c10ed2ad09d20d1d8f32570439359 Mon Sep 17 00:00:00 2001
From: Nicolas Stucki <nicolas.stucki@gmail.com>
Date: Tue, 22 Aug 2023 15:22:54 +0200
Subject: [PATCH 1/5] Handle dependent context functions

Add `FunctionTypeOfMethod` extractor that matches any kind of function
and return its method type.

We use this extractor instead of `ContextFunctionType` to all of
 * `ContextFunctionN[...]`
 * `ContextFunctionN[...] { def apply(using ...): R }` where `R` might
    be dependent on the parameters.
 * `PolyFunction { def apply(using ...): R }` where `R` might
    be dependent on the parameters. Currently this one would have at
    least one erased parameter.
---
 compiler/src/dotty/tools/dotc/ast/TreeInfo.scala |  2 +-
 .../src/dotty/tools/dotc/core/Definitions.scala  | 16 ++++++++++++++++
 .../dotc/transform/ContextFunctionResults.scala  | 16 ++++++++--------
 3 files changed, 25 insertions(+), 9 deletions(-)

diff --git a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
index 4aaef28b9e1e..9751e8272858 100644
--- a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
+++ b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
@@ -990,7 +990,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
   def isStructuralTermSelectOrApply(tree: Tree)(using Context): Boolean = {
     def isStructuralTermSelect(tree: Select) =
       def hasRefinement(qualtpe: Type): Boolean = qualtpe.dealias match
-        case defn.PolyFunctionOf(_) =>
+        case defn.FunctionTypeOfMethod(_) =>
           false
         case RefinedType(parent, rname, rinfo) =>
           rname == tree.name || hasRefinement(parent)
diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala
index 22a49a760e57..dfa43f7407eb 100644
--- a/compiler/src/dotty/tools/dotc/core/Definitions.scala
+++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala
@@ -1108,6 +1108,22 @@ class Definitions {
     //  - .linkedClass: the ClassSymbol of the enumeration (class E)
     sym.owner.linkedClass.typeRef
 
+  object FunctionTypeOfMethod {
+    /** Matches a `FunctionN[...]`/`ContextFunctionN[...]` or refined `PolyFunction`/`FunctionN[...]`/`ContextFunctionN[...]`.
+     *  Extracts the method type type and apply info.
+     */
+    def unapply(ft: Type)(using Context): Option[MethodOrPoly] = {
+      ft match
+        case RefinedType(parent, nme.apply, mt: MethodOrPoly)
+        if parent.derivesFrom(defn.PolyFunctionClass) || isFunctionNType(parent) =>
+          Some(mt)
+        case FunctionOf(argTypes, resultType, isContextual) =>
+          val methodType = if isContextual then ContextualMethodType else MethodType
+          Some(methodType(argTypes, resultType))
+        case _ => None
+    }
+  }
+
   object FunctionOf {
     def apply(args: List[Type], resultType: Type, isContextual: Boolean = false)(using Context): Type =
       val mt = MethodType.companion(isContextual, false)(args, resultType)
diff --git a/compiler/src/dotty/tools/dotc/transform/ContextFunctionResults.scala b/compiler/src/dotty/tools/dotc/transform/ContextFunctionResults.scala
index 01a77427698a..1b1d78182f0f 100644
--- a/compiler/src/dotty/tools/dotc/transform/ContextFunctionResults.scala
+++ b/compiler/src/dotty/tools/dotc/transform/ContextFunctionResults.scala
@@ -58,8 +58,8 @@ object ContextFunctionResults:
    */
   def contextResultsAreErased(sym: Symbol)(using Context): Boolean =
     def allErased(tp: Type): Boolean = tp.dealias match
-      case defn.ContextFunctionType(argTpes, resTpe) =>
-        argTpes.forall(_.hasAnnotation(defn.ErasedParamAnnot)) && allErased(resTpe)
+      case ft @ defn.FunctionTypeOfMethod(mt: MethodType) if mt.isContextualMethod =>
+        mt.nonErasedParamCount == 0 && allErased(mt.resType)
       case _ => true
     contextResultCount(sym) > 0 && allErased(sym.info.finalResultType)
 
@@ -68,13 +68,13 @@ object ContextFunctionResults:
    */
   def integrateContextResults(tp: Type, crCount: Int)(using Context): Type =
     if crCount == 0 then tp
-    else tp match
+    else tp.dealias match
       case ExprType(rt) =>
         integrateContextResults(rt, crCount)
       case tp: MethodOrPoly =>
         tp.derivedLambdaType(resType = integrateContextResults(tp.resType, crCount))
-      case defn.ContextFunctionType(argTypes, resType) =>
-        MethodType(argTypes, integrateContextResults(resType, crCount - 1))
+      case defn.FunctionTypeOfMethod(mt) if mt.isContextualMethod =>
+        mt.derivedLambdaType(resType = integrateContextResults(mt.resType, crCount - 1))
 
   /** The total number of parameters of method `sym`, not counting
    *  erased parameters, but including context result parameters.
@@ -101,7 +101,7 @@ object ContextFunctionResults:
     def recur(tp: Type, n: Int): Type =
       if n == 0 then tp
       else tp match
-        case defn.ContextFunctionType(_, resTpe) => recur(resTpe, n - 1)
+        case defn.FunctionTypeOfMethod(mt) => recur(mt.resType, n - 1)
     recur(meth.info.finalResultType, depth)
 
   /** Should selection `tree` be eliminated since it refers to an `apply`
@@ -115,8 +115,8 @@ object ContextFunctionResults:
     else tree match
       case Select(qual, name) =>
         if name == nme.apply then
-          qual.tpe match
-            case defn.ContextFunctionType(_, _) =>
+          qual.tpe.nn.dealias match
+            case defn.FunctionTypeOfMethod(mt) if mt.isContextualMethod =>
               integrateSelect(qual, n + 1)
             case _ if defn.isContextFunctionClass(tree.symbol.maybeOwner) => // for TermRefs
               integrateSelect(qual, n + 1)

From b0eaf17b8be525fe3fa91afc5b31039cc74a6d52 Mon Sep 17 00:00:00 2001
From: Nicolas Stucki <nicolas.stucki@gmail.com>
Date: Wed, 30 Aug 2023 10:42:20 +0200
Subject: [PATCH 2/5] Inline FunctionOf in FunctionTypeOfMethod and optimize

---
 compiler/src/dotty/tools/dotc/core/Definitions.scala | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala
index dfa43f7407eb..d7b1f290cb1c 100644
--- a/compiler/src/dotty/tools/dotc/core/Definitions.scala
+++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala
@@ -1117,10 +1117,14 @@ class Definitions {
         case RefinedType(parent, nme.apply, mt: MethodOrPoly)
         if parent.derivesFrom(defn.PolyFunctionClass) || isFunctionNType(parent) =>
           Some(mt)
-        case FunctionOf(argTypes, resultType, isContextual) =>
-          val methodType = if isContextual then ContextualMethodType else MethodType
-          Some(methodType(argTypes, resultType))
-        case _ => None
+        case _ =>
+          val tsym = ft.typeSymbol
+          val targs = ft.argInfos
+          if targs.nonEmpty && isFunctionSymbol(tsym) && ft.isRef(tsym) then
+            val isContextual = tsym.name.isContextFunction
+            val methodType = if isContextual then ContextualMethodType else MethodType
+            Some(methodType(targs.init, targs.last))
+          else None
     }
   }
 

From 94c01ffc784f3e74655ae7bb612d1a0312110b03 Mon Sep 17 00:00:00 2001
From: Nicolas Stucki <nicolas.stucki@gmail.com>
Date: Wed, 30 Aug 2023 11:05:13 +0200
Subject: [PATCH 3/5] Optimize `FunctionTypeOfMethod`

---
 compiler/src/dotty/tools/dotc/core/Definitions.scala | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala
index d7b1f290cb1c..ee7262fa389b 100644
--- a/compiler/src/dotty/tools/dotc/core/Definitions.scala
+++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala
@@ -1117,14 +1117,12 @@ class Definitions {
         case RefinedType(parent, nme.apply, mt: MethodOrPoly)
         if parent.derivesFrom(defn.PolyFunctionClass) || isFunctionNType(parent) =>
           Some(mt)
+        case AppliedType(parent, targs) if targs.nonEmpty && isFunctionNType(ft) =>
+          val isContextual = ft.typeSymbol.name.isContextFunction
+          val methodType = if isContextual then ContextualMethodType else MethodType
+          Some(methodType(targs.init, targs.last))
         case _ =>
-          val tsym = ft.typeSymbol
-          val targs = ft.argInfos
-          if targs.nonEmpty && isFunctionSymbol(tsym) && ft.isRef(tsym) then
-            val isContextual = tsym.name.isContextFunction
-            val methodType = if isContextual then ContextualMethodType else MethodType
-            Some(methodType(targs.init, targs.last))
-          else None
+          None
     }
   }
 

From 527dd8e37522ece8a6090abc9ce60c9190d4771a Mon Sep 17 00:00:00 2001
From: Nicolas Stucki <nicolas.stucki@gmail.com>
Date: Wed, 30 Aug 2023 11:11:37 +0200
Subject: [PATCH 4/5] Optimize FunctionTypeOfMethod RefinedType guard

---
 compiler/src/dotty/tools/dotc/core/Definitions.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala
index ee7262fa389b..e4f32ea97772 100644
--- a/compiler/src/dotty/tools/dotc/core/Definitions.scala
+++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala
@@ -1115,7 +1115,7 @@ class Definitions {
     def unapply(ft: Type)(using Context): Option[MethodOrPoly] = {
       ft match
         case RefinedType(parent, nme.apply, mt: MethodOrPoly)
-        if parent.derivesFrom(defn.PolyFunctionClass) || isFunctionNType(parent) =>
+        if parent.derivesFrom(defn.PolyFunctionClass) || (mt.isInstanceOf[MethodType] && isFunctionNType(parent)) =>
           Some(mt)
         case AppliedType(parent, targs) if targs.nonEmpty && isFunctionNType(ft) =>
           val isContextual = ft.typeSymbol.name.isContextFunction

From d5d8273fbeb6c8c303ac76d9c33f7c68511bd01b Mon Sep 17 00:00:00 2001
From: Nicolas Stucki <nicolas.stucki@gmail.com>
Date: Thu, 31 Aug 2023 17:53:34 +0200
Subject: [PATCH 5/5] Remove unnecessary guard for AppliedType

---
 compiler/src/dotty/tools/dotc/core/Definitions.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala
index e4f32ea97772..846a1f68cb79 100644
--- a/compiler/src/dotty/tools/dotc/core/Definitions.scala
+++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala
@@ -1117,7 +1117,7 @@ class Definitions {
         case RefinedType(parent, nme.apply, mt: MethodOrPoly)
         if parent.derivesFrom(defn.PolyFunctionClass) || (mt.isInstanceOf[MethodType] && isFunctionNType(parent)) =>
           Some(mt)
-        case AppliedType(parent, targs) if targs.nonEmpty && isFunctionNType(ft) =>
+        case AppliedType(parent, targs) if isFunctionNType(ft) =>
           val isContextual = ft.typeSymbol.name.isContextFunction
           val methodType = if isContextual then ContextualMethodType else MethodType
           Some(methodType(targs.init, targs.last))