Skip to content

Commit

Permalink
Preventing compilation of a @tailrec method when it does not rewrite,…
Browse files Browse the repository at this point in the history
… but an inner method does

Adding warnings if there is an annotated def at the top level that is referenced from an inner def

Potential options for different handling of defined and implicit inner methods

Changes from PR.
  • Loading branch information
Lucy Martin committed Apr 12, 2024
1 parent adf089b commit a841499
Show file tree
Hide file tree
Showing 8 changed files with 83 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ enum ErrorMessageID(val isActive: Boolean = true) extends java.lang.Enum[ErrorMe
case UnstableInlineAccessorID // errorNumber: 192
case VolatileOnValID // errorNumber: 193
case ExtensionNullifiedByMemberID // errorNumber: 194
case TailrecNestedCallID //errorNumber: 195

def errorNumber = ordinal - 1

Expand Down
14 changes: 14 additions & 0 deletions compiler/src/dotty/tools/dotc/reporting/messages.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1907,6 +1907,20 @@ class TailrecNotApplicable(symbol: Symbol)(using Context)
def explain(using Context) = ""
}

class TailrecNestedCall(definition: Symbol, innerDef: Symbol)(using Context)
extends SyntaxMsg(TailrecNestedCallID) {
def msg(using Context) = {
s"The tail recursive def ${definition.name} contains a recursive call inside the non-inlined inner def ${innerDef.name}"
}

def explain(using Context) =
"""Tail recursion is only validated and optimised directly in the definition.
|Any calls to the recursive method via an inner def cannot be validated as
|tail recursive, nor optimised if they are. To enable tail recursion from
|inner calls, mark the inner def as inline.
|""".stripMargin
}

class FailureToEliminateExistential(tp: Type, tp1: Type, tp2: Type, boundSyms: List[Symbol], classRoot: Symbol)(using Context)
extends Message(FailureToEliminateExistentialID) {
def kind = MessageKind.Compatibility
Expand Down
31 changes: 29 additions & 2 deletions compiler/src/dotty/tools/dotc/transform/TailRec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -429,10 +429,23 @@ class TailRec extends MiniPhase {
assert(false, "We should never have gotten inside a pattern")
tree

case tree: ValOrDefDef =>
case tree: ValDef =>
if (isMandatory) noTailTransform(tree.rhs)
tree

case tree: DefDef =>
if (isMandatory)
if (tree.symbol.is(Synthetic))
noTailTransform(tree.rhs)
else
// We can't tail recurse through nested definitions, so don't want to propagate to child nodes
// We don't want to fail if there is a call that would recurse (as this would be a non self recurse), so don't
// want to call noTailTransform
// We can however warn in this case, as its likely in this situation that someone would expect a tail
// recursion optimization and enabling this to optimise would be a simple case of inlining the inner method
new NestedTailRecAlerter(method, tree.symbol).traverse(tree)
tree

case _: Super | _: This | _: Literal | _: TypeTree | _: TypeDef | EmptyTree =>
tree

Expand All @@ -446,14 +459,28 @@ class TailRec extends MiniPhase {

case Return(expr, from) =>
val fromSym = from.symbol
val inTailPosition = !fromSym.is(Label) || tailPositionLabeledSyms.contains(fromSym)
val inTailPosition = tailPositionLabeledSyms.contains(fromSym) // Label returns are only tail if the label is in tail position
|| (fromSym eq method) // Method returns are only tail if we are looking at the original method
cpy.Return(tree)(transform(expr, inTailPosition), from)

case _ =>
super.transform(tree)
}
}
}

class NestedTailRecAlerter(method: Symbol, inner: Symbol) extends TreeTraverser {
override def traverse(tree: tpd.Tree)(using Context): Unit =
tree match {
case a: Apply =>
if (a.fun.symbol eq method) {
report.warning(new TailrecNestedCall(method, inner), a.srcPos)
}
traverseChildren(tree)
case _ =>
traverseChildren(tree)
}
}
}

object TailRec {
Expand Down
10 changes: 10 additions & 0 deletions tests/neg/i20105.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
-- [E195] Syntax Warning: tests\neg\i20105.scala:6:9 -------------------------------------------------------------------
6 | foo()
| ^^^^^
| The tail recursive def foo contains a recursive call inside the non-inlined inner def bar
|
| longer explanation available when compiling with `-explain`
-- [E097] Syntax Error: tests\neg\i20105.scala:3:4 ---------------------------------------------------------------------
3 |def foo(): Unit = // error
| ^
| TailRec optimisation not applicable, method foo contains no recursive calls
9 changes: 9 additions & 0 deletions tests/neg/i20105.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import scala.annotation.tailrec
@tailrec
def foo(): Unit = // error
def bar(): Unit =
if (???)
foo()
else
bar()
bar()
6 changes: 4 additions & 2 deletions tests/neg/i5397.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ object Test {
rec3 // error: not in tail position
})

@tailrec def rec4: Unit = {
def local = rec4 // error: not in tail position
// This is technically not breaching tail recursion as rec4 does not call itself, local does
// This instead fails due to having no tail recursion at all
@tailrec def rec4: Unit = { // error: no recursive calls
def local = rec4
}

@tailrec def rec5: Int = {
Expand Down
6 changes: 6 additions & 0 deletions tests/warn/i20105.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
-- [E195] Syntax Warning: tests\warn\i20105.scala:6:9 ------------------------------------------------------------------
6 | foo() // warn
| ^^^^^
| The tail recursive def foo contains a recursive call inside the non-inlined inner def bar
|
| longer explanation available when compiling with `-explain`
10 changes: 10 additions & 0 deletions tests/warn/i20105.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import scala.annotation.tailrec
@tailrec
def foo(): Unit =
def bar(): Unit =
if (???)
foo() // warn
else
bar()
bar()
foo()

0 comments on commit a841499

Please sign in to comment.