Skip to content

Commit

Permalink
Merge pull request #14523 from dotty-staging/optimize-mapblock
Browse files Browse the repository at this point in the history
Thread context through block in transforms correctly and efficiently
  • Loading branch information
bishabosha authored Feb 21, 2022
2 parents 334b37a + 3ef8e2f commit 915f4e8
Show file tree
Hide file tree
Showing 9 changed files with 141 additions and 36 deletions.
8 changes: 1 addition & 7 deletions compiler/src/dotty/tools/dotc/ast/TreeMapWithImplicits.scala
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,7 @@ class TreeMapWithImplicits extends tpd.TreeMapWithPreciseStatContexts {
override def transform(tree: Tree)(using Context): Tree = {
try tree match {
case Block(stats, expr) =>
inContext(nestedScopeCtx(stats)) {
if stats.exists(_.isInstanceOf[Import]) then
// need to transform stats and expr together to account for import visibility
val stats1 = transformStats(stats :+ expr, ctx.owner)
cpy.Block(tree)(stats1.init, stats1.last)
else super.transform(tree)
}
super.transform(tree)(using nestedScopeCtx(stats))
case tree: DefDef =>
inContext(localCtx(tree)) {
cpy.DefDef(tree)(
Expand Down
6 changes: 4 additions & 2 deletions compiler/src/dotty/tools/dotc/ast/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1399,8 +1399,8 @@ object Trees {
cpy.NamedArg(tree)(name, transform(arg))
case Assign(lhs, rhs) =>
cpy.Assign(tree)(transform(lhs), transform(rhs))
case Block(stats, expr) =>
cpy.Block(tree)(transformStats(stats, ctx.owner), transform(expr))
case blk: Block =>
transformBlock(blk)
case If(cond, thenp, elsep) =>
cpy.If(tree)(transform(cond), transform(thenp), transform(elsep))
case Closure(env, meth, tpt) =>
Expand Down Expand Up @@ -1489,6 +1489,8 @@ object Trees {

def transformStats(trees: List[Tree], exprOwner: Symbol)(using Context): List[Tree] =
transform(trees)
def transformBlock(blk: Block)(using Context): Block =
cpy.Block(blk)(transformStats(blk.stats, ctx.owner), transform(blk.expr))
def transform(trees: List[Tree])(using Context): List[Tree] =
flatten(trees mapConserve (transform(_)))
def transformSub[Tr <: Tree](tree: Tr)(using Context): Tr =
Expand Down
23 changes: 17 additions & 6 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1157,9 +1157,12 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
* - be tail-recursive where possible
* - don't re-allocate trees where nothing has changed
*/
inline def mapStatements(exprOwner: Symbol, inline op: Tree => Context ?=> Tree)(using Context): List[Tree] =
inline def mapStatements[T](
exprOwner: Symbol,
inline op: Tree => Context ?=> Tree,
inline wrapResult: List[Tree] => Context ?=> T)(using Context): T =
@tailrec
def loop(mapped: mutable.ListBuffer[Tree] | Null, unchanged: List[Tree], pending: List[Tree])(using Context): List[Tree] =
def loop(mapped: mutable.ListBuffer[Tree] | Null, unchanged: List[Tree], pending: List[Tree])(using Context): T =
pending match
case stat :: rest =>
val statCtx = stat match
Expand All @@ -1182,8 +1185,9 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
case _ => buf += stat1
loop(buf, rest, rest)(using restCtx)
case nil =>
if mapped == null then unchanged
else mapped.prependToList(unchanged)
wrapResult(
if mapped == null then unchanged
else mapped.prependToList(unchanged))

loop(null, trees, trees)
end mapStatements
Expand All @@ -1195,8 +1199,15 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
* - imports are reflected in the contexts of subsequent statements
*/
class TreeMapWithPreciseStatContexts(cpy: TreeCopier = tpd.cpy) extends TreeMap(cpy):
override def transformStats(trees: List[Tree], exprOwner: Symbol)(using Context): List[Tree] =
trees.mapStatements(exprOwner, transform(_))
def transformStats[T](trees: List[Tree], exprOwner: Symbol, wrapResult: List[Tree] => Context ?=> T)(using Context): T =
trees.mapStatements(exprOwner, transform(_), wrapResult)
final override def transformStats(trees: List[Tree], exprOwner: Symbol)(using Context): List[Tree] =
transformStats(trees, exprOwner, sameStats)
override def transformBlock(blk: Block)(using Context) =
transformStats(blk.stats, ctx.owner,
stats1 => ctx ?=> cpy.Block(blk)(stats1, transform(blk.expr)))

val sameStats: List[Tree] => Context ?=> List[Tree] = stats => stats

/** Map Inlined nodes, NamedArgs, Blocks with no statements and local references to underlying arguments.
* Also drops Inline and Block with no statements.
Expand Down
50 changes: 35 additions & 15 deletions compiler/src/dotty/tools/dotc/transform/ForwardDepChecks.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ object ForwardDepChecks:
}

/** A class to help in forward reference checking */
class LevelInfo(outerLevelAndIndex: LevelAndIndex, stats: List[Tree])(using Context)
class LevelInfo(val outer: OptLevelInfo, val owner: Symbol, stats: List[Tree])(using Context)
extends OptLevelInfo {
override val levelAndIndex: LevelAndIndex =
stats.foldLeft(outerLevelAndIndex, 0) {(mi, stat) =>
stats.foldLeft(outer.levelAndIndex, 0) {(mi, stat) =>
val (m, idx) = mi
val m1 = stat match {
case stat: MemberDef => m.updated(stat.symbol, (this, idx))
Expand Down Expand Up @@ -71,7 +71,7 @@ class ForwardDepChecks extends MiniPhase:

override def prepareForStats(trees: List[Tree])(using Context): Context =
if (ctx.owner.isTerm)
ctx.fresh.updateStore(LevelInfo, new LevelInfo(currentLevel.levelAndIndex, trees))
ctx.fresh.updateStore(LevelInfo, new LevelInfo(currentLevel, ctx.owner, trees))
else ctx

override def transformValDef(tree: ValDef)(using Context): ValDef =
Expand All @@ -89,19 +89,39 @@ class ForwardDepChecks extends MiniPhase:
tree
}

override def transformApply(tree: Apply)(using Context): Apply = {
if (isSelfConstrCall(tree)) {
assert(currentLevel.isInstanceOf[LevelInfo], s"${ctx.owner}/" + i"$tree")
val level = currentLevel.asInstanceOf[LevelInfo]
if (level.maxIndex > 0) {
// An implementation restriction to avoid VerifyErrors and lazyvals mishaps; see SI-4717
report.debuglog("refsym = " + level.refSym)
report.error("forward reference not allowed from self constructor invocation",
ctx.source.atSpan(level.refSpan))
}
}
/** Check that self constructor call does not contain references to vals or defs
* defined later in the secondary constructor's right hand side. This is tricky
* since the complete self constructor might itself be a block that originated from
* expanding named and default parameters. In that case we have to go outwards
* and find the enclosing expression that consists of that block. Test cases in
* {pos,neg}/complex-self-call.scala.
*/
private def checkSelfConstructorCall()(using Context): Unit =
// Find level info corresponding to constructor's RHS. This is the info of the
// outermost LevelInfo that has the constructor as owner.
def rhsLevelInfo(l: OptLevelInfo): OptLevelInfo = l match
case l: LevelInfo if l.owner == ctx.owner =>
rhsLevelInfo(l.outer) match
case l1: LevelInfo => l1
case _ => l
case _ =>
NoLevelInfo

rhsLevelInfo(currentLevel) match
case level: LevelInfo =>
if level.maxIndex > 0 then
report.debuglog("refsym = " + level.refSym.showLocated)
report.error("forward reference not allowed from self constructor invocation",
ctx.source.atSpan(level.refSpan))
case _ =>
assert(false, s"${ctx.owner.showLocated}")
end checkSelfConstructorCall

override def transformApply(tree: Apply)(using Context): Apply =
if (isSelfConstrCall(tree))
assert(ctx.owner.isConstructor)
checkSelfConstructorCall()
tree
}

override def transformNew(tree: New)(using Context): New = {
currentLevel.enterReference(tree.tpe.typeSymbol, tree.span)
Expand Down
17 changes: 13 additions & 4 deletions compiler/src/dotty/tools/dotc/transform/MegaPhase.scala
Original file line number Diff line number Diff line change
Expand Up @@ -296,9 +296,7 @@ class MegaPhase(val miniPhases: Array[MiniPhase]) extends Phase {
}
case tree: Block =>
inContext(prepBlock(tree, start)(using outerCtx)) {
val stats = transformStats(tree.stats, ctx.owner, start)
val expr = transformTree(tree.expr, start)
goBlock(cpy.Block(tree)(stats, expr), start)
transformBlock(tree, start)
}
case tree: TypeApply =>
inContext(prepTypeApply(tree, start)(using outerCtx)) {
Expand Down Expand Up @@ -434,9 +432,20 @@ class MegaPhase(val miniPhases: Array[MiniPhase]) extends Phase {

def transformStats(trees: List[Tree], exprOwner: Symbol, start: Int)(using Context): List[Tree] =
val nestedCtx = prepStats(trees, start)
val trees1 = trees.mapStatements(exprOwner, transformTree(_, start))(using nestedCtx)
val trees1 = trees.mapStatements(exprOwner, transformTree(_, start), stats1 => stats1)(using nestedCtx)
goStats(trees1, start)(using nestedCtx)

def transformBlock(tree: Block, start: Int)(using Context): Tree =
val nestedCtx = prepStats(tree.stats, start)
val block1 = tree.stats.mapStatements(ctx.owner,
transformTree(_, start),
stats1 => ctx ?=> {
val stats2 = goStats(stats1, start)(using nestedCtx)
val expr2 = transformTree(tree.expr, start)
cpy.Block(tree)(stats2, expr2)
})(using nestedCtx)
goBlock(block1, start)

def transformUnit(tree: Tree)(using Context): Tree = {
val nestedCtx = prepUnit(tree, 0)
val tree1 = transformTree(tree, 0)(using nestedCtx)
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/transform/PostTyper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -458,8 +458,8 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase
throw ex
}

override def transformStats(trees: List[Tree], exprOwner: Symbol)(using Context): List[Tree] =
try super.transformStats(trees, exprOwner)
override def transformStats[T](trees: List[Tree], exprOwner: Symbol, wrapResult: List[Tree] => Context ?=> T)(using Context): T =
try super.transformStats(trees, exprOwner, wrapResult)
finally Checking.checkExperimentalImports(trees)

/** Transforms the rhs tree into a its default tree if it is in an `erased` val/def.
Expand Down
10 changes: 10 additions & 0 deletions tests/explicit-nulls/pos/unsafe-chain.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import java.nio.file.FileSystems
import java.util.ArrayList

def directorySeparator: String =
import scala.language.unsafeNulls
FileSystems.getDefault().getSeparator()

def getFirstOfFirst(xs: ArrayList[ArrayList[ArrayList[String]]]): String =
import scala.language.unsafeNulls
xs.get(0).get(0).get(0)
30 changes: 30 additions & 0 deletions tests/neg/complex-self-call.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// An example extracted from akka that demonstrated a spurious
// "forward reference not allowed from self constructor invocation" error.
class Resizer
class SupervisorStrategy
class Pool
object Pool:
def defaultSupervisorStrategy: SupervisorStrategy = ???
object Dispatchers:
def DefaultDispatcherId = ???
object Resizer:
def fromConfig(config: Config): Option[Resizer] = ???

class Config:
def getInt(str: String): Int = ???
def hasPath(str: String): Boolean = ???

final case class BroadcastPool(
nrOfInstances: Int,
val resizer: Option[Resizer] = None,
val supervisorStrategy: SupervisorStrategy = Pool.defaultSupervisorStrategy,
val routerDispatcher: String = Dispatchers.DefaultDispatcherId,
val usePoolDispatcher: Boolean = false)
extends Pool:

def this(config: Config) =
this(
nrOfInstances = config.getInt("nr-of-instances"),
resizer = resiz, // error: forward reference not allowed from self constructor invocation
usePoolDispatcher = config.hasPath("pool-dispatcher"))
def resiz = Resizer.fromConfig(config)
29 changes: 29 additions & 0 deletions tests/pos/complex-self-call.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// An example extracted from akka that demonstrated a spurious
// "forward reference not allowed from self constructor invocation" error.
class Resizer
class SupervisorStrategy
class Pool
object Pool:
def defaultSupervisorStrategy: SupervisorStrategy = ???
object Dispatchers:
def DefaultDispatcherId = ???
object Resizer:
def fromConfig(config: Config): Option[Resizer] = ???

class Config:
def getInt(str: String): Int = ???
def hasPath(str: String): Boolean = ???

final case class BroadcastPool(
nrOfInstances: Int,
val resizer: Option[Resizer] = None,
val supervisorStrategy: SupervisorStrategy = Pool.defaultSupervisorStrategy,
val routerDispatcher: String = Dispatchers.DefaultDispatcherId,
val usePoolDispatcher: Boolean = false)
extends Pool:

def this(config: Config) =
this(
nrOfInstances = config.getInt("nr-of-instances"),
resizer = Resizer.fromConfig(config),
usePoolDispatcher = config.hasPath("pool-dispatcher"))

0 comments on commit 915f4e8

Please sign in to comment.