Skip to content

Commit

Permalink
List(...) optimization to avoid intermediate array (closes scala#17035)
Browse files Browse the repository at this point in the history
  • Loading branch information
KuceraMartin authored and Decel committed Oct 30, 2023
1 parent 231ca72 commit 708b55f
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 11 deletions.
20 changes: 11 additions & 9 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -515,14 +515,15 @@ class Definitions {
methodNames.map(getWrapVarargsArrayModule.requiredMethod(_))
})

@tu lazy val ListClass: Symbol = requiredClass("scala.collection.immutable.List")
def ListType: TypeRef = ListClass.typeRef
@tu lazy val ListModule: Symbol = requiredModule("scala.collection.immutable.List")
@tu lazy val NilModule: Symbol = requiredModule("scala.collection.immutable.Nil")
def NilType: TermRef = NilModule.termRef
@tu lazy val ConsClass: Symbol = requiredClass("scala.collection.immutable.::")
def ConsType: TypeRef = ConsClass.typeRef
@tu lazy val SeqFactoryClass: Symbol = requiredClass("scala.collection.SeqFactory")
@tu lazy val ListClass: Symbol = requiredClass("scala.collection.immutable.List")
def ListType: TypeRef = ListClass.typeRef
@tu lazy val ListModule: Symbol = requiredModule("scala.collection.immutable.List")
@tu lazy val ListModule_apply: Symbol = ListModule.requiredMethod(nme.apply)
@tu lazy val NilModule: Symbol = requiredModule("scala.collection.immutable.Nil")
def NilType: TermRef = NilModule.termRef
@tu lazy val ConsClass: Symbol = requiredClass("scala.collection.immutable.::")
def ConsType: TypeRef = ConsClass.typeRef
@tu lazy val SeqFactoryClass: Symbol = requiredClass("scala.collection.SeqFactory")

@tu lazy val SingletonClass: ClassSymbol =
// needed as a synthetic class because Scala 2.x refers to it in classfiles
Expand All @@ -541,7 +542,8 @@ class Definitions {
@tu lazy val Seq_lengthCompare: Symbol = SeqClass.requiredMethod(nme.lengthCompare, List(IntType))
@tu lazy val Seq_length : Symbol = SeqClass.requiredMethod(nme.length)
@tu lazy val Seq_toSeq : Symbol = SeqClass.requiredMethod(nme.toSeq)
@tu lazy val SeqModule: Symbol = requiredModule("scala.collection.immutable.Seq")
@tu lazy val SeqModule : Symbol = requiredModule("scala.collection.immutable.Seq")
@tu lazy val SeqModule_apply : Symbol = SeqModule.requiredMethod(nme.apply)


@tu lazy val StringOps: Symbol = requiredClass("scala.collection.StringOps")
Expand Down
27 changes: 25 additions & 2 deletions compiler/src/dotty/tools/dotc/transform/ArrayApply.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,18 @@ class ArrayApply extends MiniPhase {

override def description: String = ArrayApply.description

private var transformListApplyLimit = 8

private def reducingTransformListApply[A](depth: Int)(body: => A): A = {
val saved = transformListApplyLimit
transformListApplyLimit -= depth
try body
finally transformListApplyLimit = saved
}

override def transformApply(tree: tpd.Apply)(using Context): tpd.Tree =
if isArrayModuleApply(tree.symbol) then
tree.args match {
tree.args match
case StripAscription(Apply(wrapRefArrayMeth, (seqLit: tpd.JavaSeqLiteral) :: Nil)) :: ct :: Nil
if defn.WrapArrayMethods().contains(wrapRefArrayMeth.symbol) && elideClassTag(ct) =>
seqLit
Expand All @@ -35,14 +44,28 @@ class ArrayApply extends MiniPhase {

case _ =>
tree
}

else if isListOrSeqModuleApply(tree.symbol) then
tree.args match
// <List or Seq>(a, b, c) ~> new ::(a, new ::(b, new ::(c, Nil))) but only for reference types
case StripAscription(Apply(wrapArrayMeth, List(StripAscription(rest: tpd.JavaSeqLiteral)))) :: Nil
if defn.WrapArrayMethods().contains(wrapArrayMeth.symbol) &&
rest.elems.lengthIs < transformListApplyLimit =>
rest.elems.foldRight(tpd.ref(defn.NilModule)): (elem, acc) =>
tpd.New(defn.ConsType, List(elem, acc))

case _ =>
tree

else tree

private def isArrayModuleApply(sym: Symbol)(using Context): Boolean =
sym.name == nme.apply
&& (sym.owner == defn.ArrayModuleClass || (sym.owner == defn.IArrayModuleClass && !sym.is(Extension)))

private def isListOrSeqModuleApply(sym: Symbol)(using Context): Boolean =
sym == defn.ListModule_apply || sym == defn.SeqModule_apply

/** Only optimize when classtag if it is one of
* - `ClassTag.apply(classOf[XYZ])`
* - `ClassTag.apply(java.lang.XYZ.Type)` for boxed primitives `XYZ``
Expand Down
25 changes: 25 additions & 0 deletions compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -160,4 +160,29 @@ class ArrayApplyOptTest extends DottyBytecodeTest {
}
}

@Test def testListApplyAvoidsIntermediateArray = {
val source =
"""
|class Foo {
| def meth1: List[String] = List("1", "2", "3")
| def meth2: List[String] =
| new scala.collection.immutable.::("1", new scala.collection.immutable.::("2", new scala.collection.immutable.::("3", scala.collection.immutable.Nil))).asInstanceOf[List[String]]
|}
""".stripMargin

checkBCode(source) { dir =>
val clsIn = dir.lookupName("Foo.class", directory = false).input
val clsNode = loadClassNode(clsIn)
val meth1 = getMethod(clsNode, "meth1")
val meth2 = getMethod(clsNode, "meth2")

val instructions1 = instructionsFromMethod(meth1)
val instructions2 = instructionsFromMethod(meth2)

assert(instructions1 == instructions2,
"the List.apply method " +
diffInstructions(instructions1, instructions2))
}
}

}
21 changes: 21 additions & 0 deletions tests/run/list-apply-eval.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
object Test:

var counter = 0

def next =
counter += 1
counter.toString

def main(args: Array[String]): Unit =
//List.apply is subject to an optimisation in cleanup
//ensure that the arguments are evaluated in the currect order
// Rewritten to:
// val myList: List = new collection.immutable.::(Test.this.next(), new collection.immutable.::(Test.this.next(), new collection.immutable.::(Test.this.next(), scala.collection.immutable.Nil)));
val myList = List(next, next, next)
assert(myList == List("1", "2", "3"), myList)

val mySeq = Seq(next, next, next)
assert(mySeq == Seq("4", "5", "6"), mySeq)

val emptyList = List[Int]()
assert(emptyList == Nil)

0 comments on commit 708b55f

Please sign in to comment.