Skip to content

Commit

Permalink
scala#20105: Adding a warning to the case where nested named definiti…
Browse files Browse the repository at this point in the history
…ons contain non-tail recursive calls.

Code will now compile where a child def calls the parent def in a non-tail position (with the warning).
Code will no longer compile if all calls to a @tailrec method are in named child methods (as these do not tail recurse).
  • Loading branch information
Lucy Martin committed Jun 19, 2024
1 parent 133c14a commit 01ada74
Show file tree
Hide file tree
Showing 8 changed files with 83 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ enum ErrorMessageID(val isActive: Boolean = true) extends java.lang.Enum[ErrorMe
case ContextBoundCompanionNotValueID // errorNumber: 196
case InlinedAnonClassWarningID // errorNumber: 197
case UnusedSymbolID // errorNumber: 198
case TailrecNestedCallID //errorNumber: 199

def errorNumber = ordinal - 1

Expand Down
14 changes: 14 additions & 0 deletions compiler/src/dotty/tools/dotc/reporting/messages.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1918,6 +1918,20 @@ class TailrecNotApplicable(symbol: Symbol)(using Context)
def explain(using Context) = ""
}

class TailrecNestedCall(definition: Symbol, innerDef: Symbol)(using Context)
extends SyntaxMsg(TailrecNestedCallID) {
def msg(using Context) = {
s"The tail recursive def ${definition.name} contains a recursive call inside the non-inlined inner def ${innerDef.name}"
}

def explain(using Context) =
"""Tail recursion is only validated and optimised directly in the definition.
|Any calls to the recursive method via an inner def cannot be validated as
|tail recursive, nor optimised if they are. To enable tail recursion from
|inner calls, mark the inner def as inline.
|""".stripMargin
}

class FailureToEliminateExistential(tp: Type, tp1: Type, tp2: Type, boundSyms: List[Symbol], classRoot: Symbol)(using Context)
extends Message(FailureToEliminateExistentialID) {
def kind = MessageKind.Compatibility
Expand Down
31 changes: 29 additions & 2 deletions compiler/src/dotty/tools/dotc/transform/TailRec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -429,10 +429,23 @@ class TailRec extends MiniPhase {
assert(false, "We should never have gotten inside a pattern")
tree

case tree: ValOrDefDef =>
case tree: ValDef =>
if (isMandatory) noTailTransform(tree.rhs)
tree

case tree: DefDef =>
if (isMandatory)
if (tree.symbol.is(Synthetic))
noTailTransform(tree.rhs)
else
// We can't tail recurse through nested definitions, so don't want to propagate to child nodes
// We don't want to fail if there is a call that would recurse (as this would be a non self recurse), so don't
// want to call noTailTransform
// We can however warn in this case, as its likely in this situation that someone would expect a tail
// recursion optimization and enabling this to optimise would be a simple case of inlining the inner method
new NestedTailRecAlerter(method, tree.symbol).traverse(tree)
tree

case _: Super | _: This | _: Literal | _: TypeTree | _: TypeDef | EmptyTree =>
tree

Expand All @@ -446,14 +459,28 @@ class TailRec extends MiniPhase {

case Return(expr, from) =>
val fromSym = from.symbol
val inTailPosition = !fromSym.is(Label) || tailPositionLabeledSyms.contains(fromSym)
val inTailPosition = tailPositionLabeledSyms.contains(fromSym) // Label returns are only tail if the label is in tail position
|| (fromSym eq method) // Method returns are only tail if we are looking at the original method
cpy.Return(tree)(transform(expr, inTailPosition), from)

case _ =>
super.transform(tree)
}
}
}

class NestedTailRecAlerter(method: Symbol, inner: Symbol) extends TreeTraverser {
override def traverse(tree: tpd.Tree)(using Context): Unit =
tree match {
case a: Apply =>
if (a.fun.symbol eq method) {
report.warning(new TailrecNestedCall(method, inner), a.srcPos)
}
traverseChildren(tree)
case _ =>
traverseChildren(tree)
}
}
}

object TailRec {
Expand Down
10 changes: 10 additions & 0 deletions tests/neg/i20105.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
-- [E199] Syntax Warning: tests/neg/i20105.scala:6:9 -------------------------------------------------------------------
6 | foo()
| ^^^^^
| The tail recursive def foo contains a recursive call inside the non-inlined inner def bar
|
| longer explanation available when compiling with `-explain`
-- [E097] Syntax Error: tests/neg/i20105.scala:3:4 ---------------------------------------------------------------------
3 |def foo(): Unit = // error
| ^
| TailRec optimisation not applicable, method foo contains no recursive calls
9 changes: 9 additions & 0 deletions tests/neg/i20105.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import scala.annotation.tailrec
@tailrec
def foo(): Unit = // error
def bar(): Unit =
if (???)
foo()
else
bar()
bar()
6 changes: 4 additions & 2 deletions tests/neg/i5397.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ object Test {
rec3 // error: not in tail position
})

@tailrec def rec4: Unit = {
def local = rec4 // error: not in tail position
// This is technically not breaching tail recursion as rec4 does not call itself, local does
// This instead fails due to having no tail recursion at all
@tailrec def rec4: Unit = { // error: no recursive calls
def local = rec4
}

@tailrec def rec5: Int = {
Expand Down
6 changes: 6 additions & 0 deletions tests/warn/i20105.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
-- [E199] Syntax Warning: tests/warn/i20105.scala:6:9 ------------------------------------------------------------------
6 | foo() // warn
| ^^^^^
| The tail recursive def foo contains a recursive call inside the non-inlined inner def bar
|
| longer explanation available when compiling with `-explain`
10 changes: 10 additions & 0 deletions tests/warn/i20105.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import scala.annotation.tailrec
@tailrec
def foo(): Unit =
def bar(): Unit =
if (???)
foo() // warn
else
bar()
bar()
foo()

0 comments on commit 01ada74

Please sign in to comment.