Skip to content

Commit

Permalink
Check user defined PolyFunction refinements
Browse files Browse the repository at this point in the history
`PolyFunction` must be refined with an `apply` method that has a single
parameter list with no by-name nor varargs parameters. It may optionally
have type parameters. Some of these restrictions could be lifted later,
but for now these features are not properly handled by the compiler.

Fixes scala#8299
Fixes scala#18302

fixup
  • Loading branch information
nicolasstucki committed Sep 7, 2023
1 parent 5eb73b6 commit 95b6ed8
Show file tree
Hide file tree
Showing 17 changed files with 113 additions and 2 deletions.
6 changes: 4 additions & 2 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1161,10 +1161,12 @@ class Definitions {
Some(mt)
case _ => None

private def isValidPolyFunctionInfo(info: Type)(using Context): Boolean =
def isValidPolyFunctionInfo(info: Type)(using Context): Boolean =
def isValidMethodType(info: Type) = info match
case info: MethodType =>
!info.resType.isInstanceOf[MethodOrPoly] // Has only one parameter list
!info.resType.isInstanceOf[MethodOrPoly] && // Has only one parameter list
!info.isVarArgsMethod &&
!info.paramInfos.exists(_.isInstanceOf[ExprType]) // No by-name parameters
case _ => false
info match
case info: PolyType => isValidMethodType(info.resType)
Expand Down
5 changes: 5 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/PostTyper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -381,13 +381,15 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
case tree: ValDef =>
registerIfHasMacroAnnotations(tree)
checkErasedDef(tree)
Checking.checkPolyFunctionType(tree.tpt)
val tree1 = cpy.ValDef(tree)(rhs = normalizeErasedRhs(tree.rhs, tree.symbol))
if tree1.removeAttachment(desugar.UntupledParam).isDefined then
checkStableSelection(tree.rhs)
processValOrDefDef(super.transform(tree1))
case tree: DefDef =>
registerIfHasMacroAnnotations(tree)
checkErasedDef(tree)
Checking.checkPolyFunctionType(tree.tpt)
annotateContextResults(tree)
val tree1 = cpy.DefDef(tree)(rhs = normalizeErasedRhs(tree.rhs, tree.symbol))
processValOrDefDef(superAcc.wrapDefDef(tree1)(super.transform(tree1).asInstanceOf[DefDef]))
Expand Down Expand Up @@ -492,6 +494,9 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
)
case Block(_, Closure(_, _, tpt)) if ExpandSAMs.needsWrapperClass(tpt.tpe) =>
superAcc.withInvalidCurrentClass(super.transform(tree))
case tree: RefinedTypeTree =>
Checking.checkPolyFunctionType(tree)
super.transform(tree)
case _: Quote | _: QuotePattern =>
ctx.compilationUnit.needsStaging = true
super.transform(tree)
Expand Down
25 changes: 25 additions & 0 deletions compiler/src/dotty/tools/dotc/typer/Checking.scala
Original file line number Diff line number Diff line change
Expand Up @@ -821,6 +821,31 @@ object Checking {
else Feature.checkExperimentalFeature("features", imp.srcPos)
case _ =>
end checkExperimentalImports

/** Checks that PolyFunction only have valid refinements.
*
* It only supports `apply` methods with one parameter list and optional type arguments.
*/
def checkPolyFunctionType(tree: Tree)(using Context): Unit = new TreeTraverser {
def traverse(tree: Tree)(using Context): Unit = tree match
case tree: RefinedTypeTree if tree.tpe.derivesFrom(defn.PolyFunctionClass) =>
if tree.refinements.isEmpty then
reportNoRefinements(tree.srcPos)
tree.refinements.foreach {
case refinement: DefDef if refinement.name != nme.apply =>
report.error("PolyFunction only supports apply method refinements", refinement.srcPos)
case refinement: DefDef if !defn.PolyFunctionOf.isValidPolyFunctionInfo(refinement.tpe.widen) =>
report.error("Implementation restriction: PolyFunction apply must have exactly one parameter list and optionally type arguments. No by-name nor varags are allowed.", refinement.srcPos)
case _ =>
}
case _: RefTree if tree.symbol == defn.PolyFunctionClass =>
reportNoRefinements(tree.srcPos)
case _ =>
traverseChildren(tree)

def reportNoRefinements(pos: SrcPos) =
report.error("PolyFunction refinement must have a refinements of the apply method", pos)
}.traverse(tree)
}

trait Checking {
Expand Down
4 changes: 4 additions & 0 deletions tests/neg/i18302b.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-- Error: tests/neg/i18302b.scala:3:32 ---------------------------------------------------------------------------------
3 |def polyFun: PolyFunction { def apply(x: Int)(y: Int): Int } = // error
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|Implementation restriction: PolyFunction apply must have exactly one parameter list and optionally type arguments. No by-name nor varags are allowed.
5 changes: 5 additions & 0 deletions tests/neg/i18302b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
def test = polyFun(1)(2)

def polyFun: PolyFunction { def apply(x: Int)(y: Int): Int } = // error
new PolyFunction:
def apply(x: Int)(y: Int): Int = x + y
4 changes: 4 additions & 0 deletions tests/neg/i18302c.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-- Error: tests/neg/i18302c.scala:4:32 ---------------------------------------------------------------------------------
4 |def polyFun: PolyFunction { def foo(x: Int): Int } = // error
| ^^^^^^^^^^^^^^^^^^^^
| PolyFunction only supports apply method refinements
5 changes: 5 additions & 0 deletions tests/neg/i18302c.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import scala.reflect.Selectable.reflectiveSelectable

def test = polyFun.foo(1)
def polyFun: PolyFunction { def foo(x: Int): Int } = // error
new PolyFunction { def foo(x: Int): Int = x + 1 }
4 changes: 4 additions & 0 deletions tests/neg/i18302d.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-- Error: tests/neg/i18302d.scala:2:32 ---------------------------------------------------------------------------------
2 |def polyFun: PolyFunction { def apply: Int } = // error
| ^^^^^^^^^^^^^^
|Implementation restriction: PolyFunction apply must have exactly one parameter list and optionally type arguments. No by-name nor varags are allowed.
3 changes: 3 additions & 0 deletions tests/neg/i18302d.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
def test = polyFun.apply
def polyFun: PolyFunction { def apply: Int } = // error
new PolyFunction { def apply: Int = 1 }
8 changes: 8 additions & 0 deletions tests/neg/i18302e.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
-- Error: tests/neg/i18302e.scala:1:13 ---------------------------------------------------------------------------------
1 |def polyFun: PolyFunction { } = // error
| ^^^^^^^^^^^^^^^^^
| PolyFunction refinement must have a refinements of the apply method
-- Error: tests/neg/i18302e.scala:4:15 ---------------------------------------------------------------------------------
4 |def polyFun(f: PolyFunction { }) = () // error
| ^^^^^^^^^^^^^^^^^
| PolyFunction refinement must have a refinements of the apply method
4 changes: 4 additions & 0 deletions tests/neg/i18302e.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
def polyFun: PolyFunction { } = // error
new PolyFunction { }

def polyFun(f: PolyFunction { }) = () // error
12 changes: 12 additions & 0 deletions tests/neg/i18302f.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
-- Error: tests/neg/i18302f.scala:1:13 ---------------------------------------------------------------------------------
1 |def polyFun: PolyFunction = // error
| ^^^^^^^^^^^^
| PolyFunction refinement must have a refinements of the apply method
-- Error: tests/neg/i18302f.scala:4:16 ---------------------------------------------------------------------------------
4 |def polyFun2(a: PolyFunction) = () // error
| ^^^^^^^^^^^^
| PolyFunction refinement must have a refinements of the apply method
-- Error: tests/neg/i18302f.scala:6:14 ---------------------------------------------------------------------------------
6 |val polyFun3: PolyFunction = // error
| ^^^^^^^^^^^^
| PolyFunction refinement must have a refinements of the apply method
7 changes: 7 additions & 0 deletions tests/neg/i18302f.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
def polyFun: PolyFunction = // error
new PolyFunction { }

def polyFun2(a: PolyFunction) = () // error

val polyFun3: PolyFunction = // error
new PolyFunction { }
6 changes: 6 additions & 0 deletions tests/neg/i18302i.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
def polyFun1: Option[PolyFunction] = ??? // error
def polyFun2: PolyFunction & Any = ??? // error
def polyFun3: Any & PolyFunction = ??? // error
def polyFun4: PolyFunction | Any = ??? // error
def polyFun5: Any | PolyFunction = ??? // error
def polyFun6(a: Any | PolyFunction) = ??? // error
5 changes: 5 additions & 0 deletions tests/neg/i18302j.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
def polyFunByName: PolyFunction { def apply(thunk: => Int): Int } = // error
new PolyFunction { def apply(thunk: => Int): Int = 1 }

def polyFunVarArgs: PolyFunction { def apply(args: Int*): Int } = // error
new PolyFunction { def apply(thunk: Int*): Int = 1 }
8 changes: 8 additions & 0 deletions tests/neg/i8299.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package example

object Main {
def main(a: Array[String]): Unit = {
val p: PolyFunction = // error: PolyFunction refinement must have a refinements of the apply method
[A] => (xs: List[A]) => xs.headOption
}
}
4 changes: 4 additions & 0 deletions tests/pos/i18302a.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
def test = polyFun(1)

def polyFun: PolyFunction { def apply(x: Int): Int } =
new PolyFunction { def apply(x: Int): Int = x + 1 }

0 comments on commit 95b6ed8

Please sign in to comment.