Skip to content

Commit

Permalink
Fixup and finish List optimisation
Browse files Browse the repository at this point in the history
  • Loading branch information
dwijnand committed Nov 15, 2023
1 parent 90aea07 commit 2ac7c1c
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 12 deletions.
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,7 @@ class Definitions {
def ListType: TypeRef = ListClass.typeRef
@tu lazy val ListModule: Symbol = requiredModule("scala.collection.immutable.List")
@tu lazy val ListModule_apply: Symbol = ListModule.requiredMethod(nme.apply)
def ListModuleAlias: Symbol = ScalaPackageClass.requiredMethod(nme.List)
@tu lazy val NilModule: Symbol = requiredModule("scala.collection.immutable.Nil")
def NilType: TermRef = NilModule.termRef
@tu lazy val ConsClass: Symbol = requiredClass("scala.collection.immutable.::")
Expand Down
23 changes: 19 additions & 4 deletions compiler/src/dotty/tools/dotc/transform/ArrayApply.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,15 @@ class ArrayApply extends MiniPhase {
case _ =>
tree

else if isListOrSeqModuleApply(tree.symbol) then
else if isSeqApply(tree) then
tree.args match
// <List or Seq>(a, b, c) ~> new ::(a, new ::(b, new ::(c, Nil))) but only for reference types
case StripAscription(Apply(wrapArrayMeth, List(StripAscription(rest: JavaSeqLiteral)))) :: Nil
if defn.WrapArrayMethods().contains(wrapArrayMeth.symbol) &&
rest.elems.lengthIs < transformListApplyLimit =>
rest.elems.foldRight(ref(defn.NilModule)): (elem, acc) =>
val consed = rest.elems.foldRight(ref(defn.NilModule)): (elem, acc) =>
New(defn.ConsType, List(elem.ensureConforms(defn.ObjectType), acc))
consed.cast(tree.tpe)

case _ =>
tree
Expand All @@ -52,8 +53,22 @@ class ArrayApply extends MiniPhase {
sym.name == nme.apply
&& (sym.owner == defn.ArrayModuleClass || (sym.owner == defn.IArrayModuleClass && !sym.is(Extension)))

private def isListOrSeqModuleApply(sym: Symbol)(using Context): Boolean =
sym == defn.ListModule_apply || sym == defn.SeqModule_apply
private def isListApply(tree: Tree)(using Context): Boolean =
(tree.symbol == defn.ListModule_apply || tree.symbol.name == nme.apply) && appliedCore(tree).match
case Select(qual, _) =>
val sym = qual.symbol
sym == defn.ListModule
|| sym == defn.ListModuleAlias
case _ => false

private def isSeqApply(tree: Tree)(using Context): Boolean =
isListApply(tree) || tree.symbol == defn.SeqModule_apply && appliedCore(tree).match
case Select(qual, _) =>
val sym = qual.symbol
sym == defn.SeqModule
|| sym == defn.SeqModuleAlias
|| sym == defn.CollectionSeqType.symbol.companionModule
case _ => false

/** Only optimize when classtag if it is one of
* - `ClassTag.apply(classOf[XYZ])`
Expand Down
65 changes: 58 additions & 7 deletions compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
package dotty.tools.backend.jvm
package dotty.tools
package backend.jvm

import org.junit.Test
import org.junit.Assert._
Expand Down Expand Up @@ -161,26 +162,76 @@ class ArrayApplyOptTest extends DottyBytecodeTest {
}

@Test def testListApplyAvoidsIntermediateArray = {
val source =
"""
checkApplyAvoidsIntermediateArray("List"):
"""import scala.collection.immutable.{ ::, Nil }
|class Foo {
| def meth1: List[String] = List("1", "2", "3")
| def meth2: List[String] =
| new scala.collection.immutable.::("1", new scala.collection.immutable.::("2", new scala.collection.immutable.::("3", scala.collection.immutable.Nil))).asInstanceOf[List[String]]
| def meth2: List[String] = new ::("1", new ::("2", new ::("3", Nil)))
|}
""".stripMargin
}

@Test def testSeqApplyAvoidsIntermediateArray = {
checkApplyAvoidsIntermediateArray("Seq"):
"""import scala.collection.immutable.{ ::, Nil }
|class Foo {
| def meth1: Seq[String] = Seq("1", "2", "3")
| def meth2: Seq[String] = new ::("1", new ::("2", new ::("3", Nil)))
|}
""".stripMargin
}

@Test def testSeqApplyAvoidsIntermediateArray2 = {
checkApplyAvoidsIntermediateArray("scala.collection.immutable.Seq"):
"""import scala.collection.immutable.{ ::, Seq, Nil }
|class Foo {
| def meth1: Seq[String] = Seq("1", "2", "3")
| def meth2: Seq[String] = new ::("1", new ::("2", new ::("3", Nil)))
|}
""".stripMargin
}

@Test def testSeqApplyAvoidsIntermediateArray3 = {
checkApplyAvoidsIntermediateArray("scala.collection.Seq"):
"""import scala.collection.immutable.{ ::, Nil }, scala.collection.Seq
|class Foo {
| def meth1: Seq[String] = Seq("1", "2", "3")
| def meth2: Seq[String] = new ::("1", new ::("2", new ::("3", Nil)))
|}
""".stripMargin
}

def checkApplyAvoidsIntermediateArray(name: String)(source: String) = {
checkBCode(source) { dir =>
val clsIn = dir.lookupName("Foo.class", directory = false).input
val clsNode = loadClassNode(clsIn)
val meth1 = getMethod(clsNode, "meth1")
val meth2 = getMethod(clsNode, "meth2")

val instructions1 = instructionsFromMethod(meth1)
val instructions1 = instructionsFromMethod(meth1) match
case instr :+ TypeOp(CHECKCAST, _) :+ TypeOp(CHECKCAST, _) :+ (ret @ Op(ARETURN)) =>
instr :+ ret
case instr :+ TypeOp(CHECKCAST, _) :+ (ret @ Op(ARETURN)) =>
// List.apply[?A] doesn't, strictly, return List[?A],
// because it cascades to its definition on IterableFactory
// where it returns CC[A]. The erasure of that is Object,
// which is why Erasure's Typer adds a cast to compensate.
// If we drop that cast while optimising (because using
// the constructor for :: doesn't require the cast like
// List.apply did) then then cons construction chain will
// be typed as ::.
// Unfortunately the LUB of :: and Nil.type is Product
// instead of List, so a cast remains necessary,
// across whatever causes the lub, like `if` or `try` branches.
// Therefore if we dropping the cast may cause a needed cast
// to be necessary, we shouldn't drop the cast,
// which was only motivated by the assert here.
instr :+ ret
case instr => instr
val instructions2 = instructionsFromMethod(meth2)

assert(instructions1 == instructions2,
"the List.apply method " +
s"the $name.apply method\n" +
diffInstructions(instructions1, instructions2))
}
}
Expand Down
14 changes: 13 additions & 1 deletion tests/run/list-apply-eval.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ object Test:
counter += 1
counter.toString

def main(args: Array[String]): Unit =
def main(args: Array[String]): Unit =
//List.apply is subject to an optimisation in cleanup
//ensure that the arguments are evaluated in the currect order
// Rewritten to:
Expand All @@ -19,3 +19,15 @@ object Test:

val emptyList = List[Int]()
assert(emptyList == Nil)

// just assert it doesn't throw CCE to List
val queue = scala.collection.mutable.Queue[String]()

// test for the cast instruction described in checkApplyAvoidsIntermediateArray
def lub(b: Boolean): List[(String, String)] =
if b then List(("foo", "bar")) else Nil

// from minimising CI failure in oslib
// again, the lub of :: and Nil is Product, which breaks ++ (which requires IterableOnce)
def lub2(b: Boolean): Unit =
Seq(1) ++ (if (b) Seq(2) else Nil)

0 comments on commit 2ac7c1c

Please sign in to comment.