From a6732dc99f3cd1e3261e4b7a9dff60f53b32728f Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Sun, 20 Feb 2022 15:08:44 +0100 Subject: [PATCH] Thread context through block in transforms correctly and efficiently Fixes #14319 --- .../tools/dotc/ast/TreeMapWithImplicits.scala | 8 +------ compiler/src/dotty/tools/dotc/ast/Trees.scala | 6 +++-- compiler/src/dotty/tools/dotc/ast/tpd.scala | 23 ++++++++++++++----- .../tools/dotc/transform/MegaPhase.scala | 17 ++++++++++---- .../tools/dotc/transform/PostTyper.scala | 4 ++-- tests/explicit-nulls/pos/unsafe-chain.scala | 10 ++++++++ 6 files changed, 47 insertions(+), 21 deletions(-) create mode 100644 tests/explicit-nulls/pos/unsafe-chain.scala diff --git a/compiler/src/dotty/tools/dotc/ast/TreeMapWithImplicits.scala b/compiler/src/dotty/tools/dotc/ast/TreeMapWithImplicits.scala index 3f4ff4687787..999a80c5e446 100644 --- a/compiler/src/dotty/tools/dotc/ast/TreeMapWithImplicits.scala +++ b/compiler/src/dotty/tools/dotc/ast/TreeMapWithImplicits.scala @@ -48,13 +48,7 @@ class TreeMapWithImplicits extends tpd.TreeMapWithPreciseStatContexts { override def transform(tree: Tree)(using Context): Tree = { try tree match { case Block(stats, expr) => - inContext(nestedScopeCtx(stats)) { - if stats.exists(_.isInstanceOf[Import]) then - // need to transform stats and expr together to account for import visibility - val stats1 = transformStats(stats :+ expr, ctx.owner) - cpy.Block(tree)(stats1.init, stats1.last) - else super.transform(tree) - } + super.transform(tree)(using nestedScopeCtx(stats)) case tree: DefDef => inContext(localCtx(tree)) { cpy.DefDef(tree)( diff --git a/compiler/src/dotty/tools/dotc/ast/Trees.scala b/compiler/src/dotty/tools/dotc/ast/Trees.scala index 7aa4491c31de..9a371e946e08 100644 --- a/compiler/src/dotty/tools/dotc/ast/Trees.scala +++ b/compiler/src/dotty/tools/dotc/ast/Trees.scala @@ -1399,8 +1399,8 @@ object Trees { cpy.NamedArg(tree)(name, transform(arg)) case Assign(lhs, rhs) => cpy.Assign(tree)(transform(lhs), transform(rhs)) - case Block(stats, expr) => - cpy.Block(tree)(transformStats(stats, ctx.owner), transform(expr)) + case blk: Block => + transformBlock(blk) case If(cond, thenp, elsep) => cpy.If(tree)(transform(cond), transform(thenp), transform(elsep)) case Closure(env, meth, tpt) => @@ -1489,6 +1489,8 @@ object Trees { def transformStats(trees: List[Tree], exprOwner: Symbol)(using Context): List[Tree] = transform(trees) + def transformBlock(blk: Block)(using Context): Block = + cpy.Block(blk)(transformStats(blk.stats, ctx.owner), transform(blk.expr)) def transform(trees: List[Tree])(using Context): List[Tree] = flatten(trees mapConserve (transform(_))) def transformSub[Tr <: Tree](tree: Tr)(using Context): Tr = diff --git a/compiler/src/dotty/tools/dotc/ast/tpd.scala b/compiler/src/dotty/tools/dotc/ast/tpd.scala index 396d85a8c58a..513557a9b40a 100644 --- a/compiler/src/dotty/tools/dotc/ast/tpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/tpd.scala @@ -1157,9 +1157,12 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { * - be tail-recursive where possible * - don't re-allocate trees where nothing has changed */ - inline def mapStatements(exprOwner: Symbol, inline op: Tree => Context ?=> Tree)(using Context): List[Tree] = + inline def mapStatements[T]( + exprOwner: Symbol, + inline op: Tree => Context ?=> Tree, + inline wrapResult: List[Tree] => Context ?=> T)(using Context): T = @tailrec - def loop(mapped: mutable.ListBuffer[Tree] | Null, unchanged: List[Tree], pending: List[Tree])(using Context): List[Tree] = + def loop(mapped: mutable.ListBuffer[Tree] | Null, unchanged: List[Tree], pending: List[Tree])(using Context): T = pending match case stat :: rest => val statCtx = stat match @@ -1182,8 +1185,9 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { case _ => buf += stat1 loop(buf, rest, rest)(using restCtx) case nil => - if mapped == null then unchanged - else mapped.prependToList(unchanged) + wrapResult( + if mapped == null then unchanged + else mapped.prependToList(unchanged)) loop(null, trees, trees) end mapStatements @@ -1195,8 +1199,15 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { * - imports are reflected in the contexts of subsequent statements */ class TreeMapWithPreciseStatContexts(cpy: TreeCopier = tpd.cpy) extends TreeMap(cpy): - override def transformStats(trees: List[Tree], exprOwner: Symbol)(using Context): List[Tree] = - trees.mapStatements(exprOwner, transform(_)) + def transformStats[T](trees: List[Tree], exprOwner: Symbol, wrapResult: List[Tree] => Context ?=> T)(using Context): T = + trees.mapStatements(exprOwner, transform(_), wrapResult) + final override def transformStats(trees: List[Tree], exprOwner: Symbol)(using Context): List[Tree] = + transformStats(trees, exprOwner, sameStats) + override def transformBlock(blk: Block)(using Context) = + transformStats(blk.stats, ctx.owner, + stats1 => ctx ?=> cpy.Block(blk)(stats1, transform(blk.expr))) + + val sameStats: List[Tree] => Context ?=> List[Tree] = stats => stats /** Map Inlined nodes, NamedArgs, Blocks with no statements and local references to underlying arguments. * Also drops Inline and Block with no statements. diff --git a/compiler/src/dotty/tools/dotc/transform/MegaPhase.scala b/compiler/src/dotty/tools/dotc/transform/MegaPhase.scala index 56342322824c..8de60186b26c 100644 --- a/compiler/src/dotty/tools/dotc/transform/MegaPhase.scala +++ b/compiler/src/dotty/tools/dotc/transform/MegaPhase.scala @@ -296,9 +296,7 @@ class MegaPhase(val miniPhases: Array[MiniPhase]) extends Phase { } case tree: Block => inContext(prepBlock(tree, start)(using outerCtx)) { - val stats = transformStats(tree.stats, ctx.owner, start) - val expr = transformTree(tree.expr, start) - goBlock(cpy.Block(tree)(stats, expr), start) + transformBlock(tree, start) } case tree: TypeApply => inContext(prepTypeApply(tree, start)(using outerCtx)) { @@ -434,9 +432,20 @@ class MegaPhase(val miniPhases: Array[MiniPhase]) extends Phase { def transformStats(trees: List[Tree], exprOwner: Symbol, start: Int)(using Context): List[Tree] = val nestedCtx = prepStats(trees, start) - val trees1 = trees.mapStatements(exprOwner, transformTree(_, start))(using nestedCtx) + val trees1 = trees.mapStatements(exprOwner, transformTree(_, start), stats1 => stats1)(using nestedCtx) goStats(trees1, start)(using nestedCtx) + def transformBlock(tree: Block, start: Int)(using Context): Tree = + val nestedCtx = prepStats(tree.stats, start) + val block1 = tree.stats.mapStatements(ctx.owner, + transformTree(_, start), + stats1 => ctx ?=> { + val stats2 = goStats(stats1, start)(using nestedCtx) + val expr2 = transformTree(tree.expr, start) + cpy.Block(tree)(stats2, expr2) + })(using nestedCtx) + goBlock(block1, start) + def transformUnit(tree: Tree)(using Context): Tree = { val nestedCtx = prepUnit(tree, 0) val tree1 = transformTree(tree, 0)(using nestedCtx) diff --git a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala index bfcf6cc6e5bf..926600ebbdd4 100644 --- a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala +++ b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala @@ -458,8 +458,8 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase throw ex } - override def transformStats(trees: List[Tree], exprOwner: Symbol)(using Context): List[Tree] = - try super.transformStats(trees, exprOwner) + override def transformStats[T](trees: List[Tree], exprOwner: Symbol, wrapResult: List[Tree] => Context ?=> T)(using Context): T = + try super.transformStats(trees, exprOwner, wrapResult) finally Checking.checkExperimentalImports(trees) /** Transforms the rhs tree into a its default tree if it is in an `erased` val/def. diff --git a/tests/explicit-nulls/pos/unsafe-chain.scala b/tests/explicit-nulls/pos/unsafe-chain.scala new file mode 100644 index 000000000000..76c80d0c53fe --- /dev/null +++ b/tests/explicit-nulls/pos/unsafe-chain.scala @@ -0,0 +1,10 @@ +import java.nio.file.FileSystems +import java.util.ArrayList + +def directorySeparator: String = + import scala.language.unsafeNulls + FileSystems.getDefault().getSeparator() + +def getFirstOfFirst(xs: ArrayList[ArrayList[ArrayList[String]]]): String = + import scala.language.unsafeNulls + xs.get(0).get(0).get(0) \ No newline at end of file