Skip to content

Commit

Permalink
Thread context through block in transforms correctly and efficiently
Browse files Browse the repository at this point in the history
Fixes #14319
  • Loading branch information
odersky committed Feb 20, 2022
1 parent d09dd2a commit a6732dc
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 21 deletions.
8 changes: 1 addition & 7 deletions compiler/src/dotty/tools/dotc/ast/TreeMapWithImplicits.scala
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,7 @@ class TreeMapWithImplicits extends tpd.TreeMapWithPreciseStatContexts {
override def transform(tree: Tree)(using Context): Tree = {
try tree match {
case Block(stats, expr) =>
inContext(nestedScopeCtx(stats)) {
if stats.exists(_.isInstanceOf[Import]) then
// need to transform stats and expr together to account for import visibility
val stats1 = transformStats(stats :+ expr, ctx.owner)
cpy.Block(tree)(stats1.init, stats1.last)
else super.transform(tree)
}
super.transform(tree)(using nestedScopeCtx(stats))
case tree: DefDef =>
inContext(localCtx(tree)) {
cpy.DefDef(tree)(
Expand Down
6 changes: 4 additions & 2 deletions compiler/src/dotty/tools/dotc/ast/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1399,8 +1399,8 @@ object Trees {
cpy.NamedArg(tree)(name, transform(arg))
case Assign(lhs, rhs) =>
cpy.Assign(tree)(transform(lhs), transform(rhs))
case Block(stats, expr) =>
cpy.Block(tree)(transformStats(stats, ctx.owner), transform(expr))
case blk: Block =>
transformBlock(blk)
case If(cond, thenp, elsep) =>
cpy.If(tree)(transform(cond), transform(thenp), transform(elsep))
case Closure(env, meth, tpt) =>
Expand Down Expand Up @@ -1489,6 +1489,8 @@ object Trees {

def transformStats(trees: List[Tree], exprOwner: Symbol)(using Context): List[Tree] =
transform(trees)
def transformBlock(blk: Block)(using Context): Block =
cpy.Block(blk)(transformStats(blk.stats, ctx.owner), transform(blk.expr))
def transform(trees: List[Tree])(using Context): List[Tree] =
flatten(trees mapConserve (transform(_)))
def transformSub[Tr <: Tree](tree: Tr)(using Context): Tr =
Expand Down
23 changes: 17 additions & 6 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1157,9 +1157,12 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
* - be tail-recursive where possible
* - don't re-allocate trees where nothing has changed
*/
inline def mapStatements(exprOwner: Symbol, inline op: Tree => Context ?=> Tree)(using Context): List[Tree] =
inline def mapStatements[T](
exprOwner: Symbol,
inline op: Tree => Context ?=> Tree,
inline wrapResult: List[Tree] => Context ?=> T)(using Context): T =
@tailrec
def loop(mapped: mutable.ListBuffer[Tree] | Null, unchanged: List[Tree], pending: List[Tree])(using Context): List[Tree] =
def loop(mapped: mutable.ListBuffer[Tree] | Null, unchanged: List[Tree], pending: List[Tree])(using Context): T =
pending match
case stat :: rest =>
val statCtx = stat match
Expand All @@ -1182,8 +1185,9 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
case _ => buf += stat1
loop(buf, rest, rest)(using restCtx)
case nil =>
if mapped == null then unchanged
else mapped.prependToList(unchanged)
wrapResult(
if mapped == null then unchanged
else mapped.prependToList(unchanged))

loop(null, trees, trees)
end mapStatements
Expand All @@ -1195,8 +1199,15 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
* - imports are reflected in the contexts of subsequent statements
*/
class TreeMapWithPreciseStatContexts(cpy: TreeCopier = tpd.cpy) extends TreeMap(cpy):
override def transformStats(trees: List[Tree], exprOwner: Symbol)(using Context): List[Tree] =
trees.mapStatements(exprOwner, transform(_))
def transformStats[T](trees: List[Tree], exprOwner: Symbol, wrapResult: List[Tree] => Context ?=> T)(using Context): T =
trees.mapStatements(exprOwner, transform(_), wrapResult)
final override def transformStats(trees: List[Tree], exprOwner: Symbol)(using Context): List[Tree] =
transformStats(trees, exprOwner, sameStats)
override def transformBlock(blk: Block)(using Context) =
transformStats(blk.stats, ctx.owner,
stats1 => ctx ?=> cpy.Block(blk)(stats1, transform(blk.expr)))

val sameStats: List[Tree] => Context ?=> List[Tree] = stats => stats

/** Map Inlined nodes, NamedArgs, Blocks with no statements and local references to underlying arguments.
* Also drops Inline and Block with no statements.
Expand Down
17 changes: 13 additions & 4 deletions compiler/src/dotty/tools/dotc/transform/MegaPhase.scala
Original file line number Diff line number Diff line change
Expand Up @@ -296,9 +296,7 @@ class MegaPhase(val miniPhases: Array[MiniPhase]) extends Phase {
}
case tree: Block =>
inContext(prepBlock(tree, start)(using outerCtx)) {
val stats = transformStats(tree.stats, ctx.owner, start)
val expr = transformTree(tree.expr, start)
goBlock(cpy.Block(tree)(stats, expr), start)
transformBlock(tree, start)
}
case tree: TypeApply =>
inContext(prepTypeApply(tree, start)(using outerCtx)) {
Expand Down Expand Up @@ -434,9 +432,20 @@ class MegaPhase(val miniPhases: Array[MiniPhase]) extends Phase {

def transformStats(trees: List[Tree], exprOwner: Symbol, start: Int)(using Context): List[Tree] =
val nestedCtx = prepStats(trees, start)
val trees1 = trees.mapStatements(exprOwner, transformTree(_, start))(using nestedCtx)
val trees1 = trees.mapStatements(exprOwner, transformTree(_, start), stats1 => stats1)(using nestedCtx)
goStats(trees1, start)(using nestedCtx)

def transformBlock(tree: Block, start: Int)(using Context): Tree =
val nestedCtx = prepStats(tree.stats, start)
val block1 = tree.stats.mapStatements(ctx.owner,
transformTree(_, start),
stats1 => ctx ?=> {
val stats2 = goStats(stats1, start)(using nestedCtx)
val expr2 = transformTree(tree.expr, start)
cpy.Block(tree)(stats2, expr2)
})(using nestedCtx)
goBlock(block1, start)

def transformUnit(tree: Tree)(using Context): Tree = {
val nestedCtx = prepUnit(tree, 0)
val tree1 = transformTree(tree, 0)(using nestedCtx)
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/transform/PostTyper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -458,8 +458,8 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase
throw ex
}

override def transformStats(trees: List[Tree], exprOwner: Symbol)(using Context): List[Tree] =
try super.transformStats(trees, exprOwner)
override def transformStats[T](trees: List[Tree], exprOwner: Symbol, wrapResult: List[Tree] => Context ?=> T)(using Context): T =
try super.transformStats(trees, exprOwner, wrapResult)
finally Checking.checkExperimentalImports(trees)

/** Transforms the rhs tree into a its default tree if it is in an `erased` val/def.
Expand Down
10 changes: 10 additions & 0 deletions tests/explicit-nulls/pos/unsafe-chain.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import java.nio.file.FileSystems
import java.util.ArrayList

def directorySeparator: String =
import scala.language.unsafeNulls
FileSystems.getDefault().getSeparator()

def getFirstOfFirst(xs: ArrayList[ArrayList[ArrayList[String]]]): String =
import scala.language.unsafeNulls
xs.get(0).get(0).get(0)

0 comments on commit a6732dc

Please sign in to comment.