Skip to content

Commit

Permalink
SIP-62 - For comprehension improvements (#20522)
Browse files Browse the repository at this point in the history
Implementation for SIP-62.

### Summary of the changes

For more details read the committed markdown file here:
scala/improvement-proposals#79

This introduces improvements to `for` comprehensions in Scala to improve
usability and simplify desugaring. The changes are hidden behind a
language import `scala.language.experimental.betterFors`.
The main changes are:

1. **Starting `for` comprehensions with aliases**: 
   - **Current Syntax**:
     ```scala
     val a = 1
     for {
       b <- Some(2)
       c <- doSth(a)
     } yield b + c
     ```
   - **New Syntax**:
     ```scala
     for {
       a = 1
       b <- Some(2)
       c <- doSth(a)
     } yield b + c
     ```

2. **Simpler Desugaring for Pure Aliases**:
   - **Current Desugaring**:
     ```scala
     for {
       a <- doSth(arg)
       b = a
     } yield a + b
     ```
     Desugars to:
     ```scala
     doSth(arg).map { a =>
       val b = a
       (a, b)
     }.map { case (a, b) =>
       a + b
     }
     ```
   - **New Desugaring**: (where possible)
     ```scala
     doSth(arg).map { a =>
       val b = a
       a + b
     }
     ```

3. **Avoiding Redundant `map` Calls**:
   - **Current Desugaring**:
     ```scala
     for {
       a <- List(1, 2, 3)
     } yield a
     ```
     Desugars to:
     ```scala
     List(1, 2, 3).map(a => a)
     ```
   - **New Desugaring**:
     ```scala
     List(1, 2, 3)
     ```
  • Loading branch information
odersky authored Aug 1, 2024
2 parents 1b644f6 + 4bc0a4a commit e261fa2
Show file tree
Hide file tree
Showing 10 changed files with 257 additions and 25 deletions.
109 changes: 88 additions & 21 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import NameKinds.{UniqueName, ContextBoundParamName, ContextFunctionParamName, D
import typer.{Namer, Checking}
import util.{Property, SourceFile, SourcePosition, SrcPos, Chars}
import config.{Feature, Config}
import config.Feature.{sourceVersion, migrateTo3, enabled, betterForsEnabled}
import config.SourceVersion.*
import collection.mutable
import reporting.*
Expand Down Expand Up @@ -1815,46 +1816,81 @@ object desugar {
/** Create tree for for-comprehension `<for (enums) do body>` or
* `<for (enums) yield body>` where mapName and flatMapName are chosen
* corresponding to whether this is a for-do or a for-yield.
* The creation performs the following rewrite rules:
* If betterFors are enabled, the creation performs the following rewrite rules:
*
* 1.
* 1. if betterFors is enabled:
*
* for (P <- G) E ==> G.foreach (P => E)
* for () do E ==> E
* or
* for () yield E ==> E
*
* Here and in the following (P => E) is interpreted as the function (P => E)
* if P is a variable pattern and as the partial function { case P => E } otherwise.
* (Where empty for-comprehensions are excluded by the parser)
*
* 2.
*
* for (P <- G) yield E ==> G.map (P => E)
* for (P <- G) do E ==> G.foreach (P => E)
*
* Here and in the following (P => E) is interpreted as the function (P => E)
* if P is a variable pattern and as the partial function { case P => E } otherwise.
*
* 3.
*
* for (P <- G) yield P ==> G
*
* If betterFors is enabled, P is a variable or a tuple of variables and G is not a withFilter.
*
* for (P <- G) yield E ==> G.map (P => E)
*
* Otherwise
*
* 4.
*
* for (P_1 <- G_1; P_2 <- G_2; ...) ...
* ==>
* G_1.flatMap (P_1 => for (P_2 <- G_2; ...) ...)
*
* 4.
* 5.
*
* for (P <- G; E; ...) ...
* =>
* for (P <- G.filter (P => E); ...) ...
* for (P <- G; if E; ...) ...
* ==>
* for (P <- G.withFilter (P => E); ...) ...
*
* 5. For any N:
* 6. For any N, if betterFors is enabled:
*
* for (P_1 <- G; P_2 = E_2; val P_N = E_N; ...)
* for (P <- G; P_1 = E_1; ... P_N = E_N; P1 <- G1; ...) ...
* ==>
* for (TupleN(P_1, P_2, ... P_N) <-
* for (x_1 @ P_1 <- G) yield {
* val x_2 @ P_2 = E_2
* G.flatMap (P => for (P_1 = E_1; ... P_N = E_N; ...))
*
* 7. For any N, if betterFors is enabled:
*
* for (P <- G; P_1 = E_1; ... P_N = E_N) ...
* ==>
* G.map (P => for (P_1 = E_1; ... P_N = E_N) ...)
*
* 8. For any N:
*
* for (P <- G; P_1 = E_1; ... P_N = E_N; ...)
* ==>
* for (TupleN(P, P_1, ... P_N) <-
* for (x @ P <- G) yield {
* val x_1 @ P_1 = E_2
* ...
* val x_N & P_N = E_N
* TupleN(x_1, ..., x_N)
* } ...)
* val x_N @ P_N = E_N
* TupleN(x, x_1, ..., x_N)
* }; if E; ...)
*
* If any of the P_i are variable patterns, the corresponding `x_i @ P_i` is not generated
* and the variable constituting P_i is used instead of x_i
*
* 9. For any N, if betterFors is enabled:
*
* for (P_1 = E_1; ... P_N = E_N; ...)
* ==>
* {
* val x_N @ P_N = E_N
* for (...)
* }
*
* @param mapName The name to be used for maps (either map or foreach)
* @param flatMapName The name to be used for flatMaps (either flatMap or foreach)
* @param enums The enumerators in the for expression
Expand Down Expand Up @@ -1963,7 +1999,7 @@ object desugar {
case GenCheckMode.FilterAlways => false // pattern was prefixed by `case`
case GenCheckMode.FilterNow | GenCheckMode.CheckAndFilter => isVarBinding(gen.pat) || isIrrefutable(gen.pat, gen.expr)
case GenCheckMode.Check => true
case GenCheckMode.Ignore => true
case GenCheckMode.Ignore | GenCheckMode.Filtered => true

/** rhs.name with a pattern filter on rhs unless `pat` is irrefutable when
* matched against `rhs`.
Expand All @@ -1973,12 +2009,31 @@ object desugar {
Select(rhs, name)
}

def deepEquals(t1: Tree, t2: Tree): Boolean =
(unsplice(t1), unsplice(t2)) match
case (Ident(n1), Ident(n2)) => n1 == n2
case (Tuple(ts1), Tuple(ts2)) => ts1.corresponds(ts2)(deepEquals)
case _ => false

enums match {
case Nil if betterForsEnabled => body
case (gen: GenFrom) :: Nil =>
Apply(rhsSelect(gen, mapName), makeLambda(gen, body))
if betterForsEnabled
&& gen.checkMode != GenCheckMode.Filtered // results of withFilter have the wrong type
&& deepEquals(gen.pat, body)
then gen.expr // avoid a redundant map with identity
else Apply(rhsSelect(gen, mapName), makeLambda(gen, body))
case (gen: GenFrom) :: (rest @ (GenFrom(_, _, _) :: _)) =>
val cont = makeFor(mapName, flatMapName, rest, body)
Apply(rhsSelect(gen, flatMapName), makeLambda(gen, cont))
case (gen: GenFrom) :: rest
if betterForsEnabled
&& rest.dropWhile(_.isInstanceOf[GenAlias]).headOption.forall(e => e.isInstanceOf[GenFrom]) => // possible aliases followed by a generator or end of for
val cont = makeFor(mapName, flatMapName, rest, body)
val selectName =
if rest.exists(_.isInstanceOf[GenFrom]) then flatMapName
else mapName
Apply(rhsSelect(gen, selectName), makeLambda(gen, cont))
case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) =>
val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias])
val pats = valeqs map { case GenAlias(pat, _) => pat }
Expand All @@ -1997,8 +2052,20 @@ object desugar {
makeFor(mapName, flatMapName, vfrom1 :: rest1, body)
case (gen: GenFrom) :: test :: rest =>
val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen, test))
val genFrom = GenFrom(gen.pat, filtered, GenCheckMode.Ignore)
val genFrom = GenFrom(gen.pat, filtered, if betterForsEnabled then GenCheckMode.Filtered else GenCheckMode.Ignore)
makeFor(mapName, flatMapName, genFrom :: rest, body)
case GenAlias(_, _) :: _ if betterForsEnabled =>
val (valeqs, rest) = enums.span(_.isInstanceOf[GenAlias])
val pats = valeqs.map { case GenAlias(pat, _) => pat }
val rhss = valeqs.map { case GenAlias(_, rhs) => rhs }
val (defpats, ids) = pats.map(makeIdPat).unzip
val pdefs = valeqs.lazyZip(defpats).lazyZip(rhss).map { (valeq, defpat, rhs) =>
val mods = defpat match
case defTree: DefTree => defTree.mods
case _ => Modifiers()
makePatDef(valeq, mods, defpat, rhs)
}
Block(pdefs, makeFor(mapName, flatMapName, rest, body))
case _ =>
EmptyTree //may happen for erroneous input
}
Expand Down
3 changes: 2 additions & 1 deletion compiler/src/dotty/tools/dotc/ast/untpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {

/** An enum to control checking or filtering of patterns in GenFrom trees */
enum GenCheckMode {
case Ignore // neither filter nor check since filtering was done before
case Ignore // neither filter nor check since pattern is trivially irrefutable
case Filtered // neither filter nor check since filtering was done before
case Check // check that pattern is irrefutable
case CheckAndFilter // both check and filter (transitional period starting with 3.2)
case FilterNow // filter out non-matching elements if we are not in 3.2 or later
Expand Down
6 changes: 5 additions & 1 deletion compiler/src/dotty/tools/dotc/config/Feature.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ object Feature:
val modularity = experimental("modularity")
val betterMatchTypeExtractors = experimental("betterMatchTypeExtractors")
val quotedPatternsWithPolymorphicFunctions = experimental("quotedPatternsWithPolymorphicFunctions")
val betterFors = experimental("betterFors")

def experimentalAutoEnableFeatures(using Context): List[TermName] =
defn.languageExperimentalFeatures
Expand Down Expand Up @@ -67,7 +68,8 @@ object Feature:
(into, "Allow into modifier on parameter types"),
(namedTuples, "Allow named tuples"),
(modularity, "Enable experimental modularity features"),
(betterMatchTypeExtractors, "Enable better match type extractors")
(betterMatchTypeExtractors, "Enable better match type extractors"),
(betterFors, "Enable improvements in `for` comprehensions")
)

// legacy language features from Scala 2 that are no longer supported.
Expand Down Expand Up @@ -125,6 +127,8 @@ object Feature:
def clauseInterleavingEnabled(using Context) =
sourceVersion.isAtLeast(`3.6`) || enabled(clauseInterleaving)

def betterForsEnabled(using Context) = enabled(betterFors)

def genericNumberLiteralsEnabled(using Context) = enabled(genericNumberLiterals)

def scala2ExperimentalMacroEnabled(using Context) = enabled(scala2macros)
Expand Down
18 changes: 17 additions & 1 deletion compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2894,7 +2894,11 @@ object Parsers {

/** Enumerators ::= Generator {semi Enumerator | Guard}
*/
def enumerators(): List[Tree] = generator() :: enumeratorsRest()
def enumerators(): List[Tree] =
if in.featureEnabled(Feature.betterFors) then
aliasesUntilGenerator() ++ enumeratorsRest()
else
generator() :: enumeratorsRest()

def enumeratorsRest(): List[Tree] =
if (isStatSep) {
Expand Down Expand Up @@ -2936,6 +2940,18 @@ object Parsers {
GenFrom(pat, subExpr(), checkMode)
}

def aliasesUntilGenerator(): List[Tree] =
if in.token == CASE then generator() :: Nil
else {
val pat = pattern1()
if in.token == EQUALS then
atSpan(startOffset(pat), in.skipToken()) { GenAlias(pat, subExpr()) } :: {
if (isStatSep) in.nextToken()
aliasesUntilGenerator()
}
else generatorRest(pat, casePat = false) :: Nil
}

/** ForExpr ::= ‘for’ ‘(’ Enumerators ‘)’ {nl} [‘do‘ | ‘yield’] Expr
* | ‘for’ ‘{’ Enumerators ‘}’ {nl} [‘do‘ | ‘yield’] Expr
* | ‘for’ Enumerators (‘do‘ | ‘yield’) Expr
Expand Down
6 changes: 6 additions & 0 deletions library/src/scala/runtime/stdLibPatches/language.scala
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@ object language:
@compileTimeOnly("`quotedPatternsWithPolymorphicFunctions` can only be used at compile time in import statements")
object quotedPatternsWithPolymorphicFunctions

/** Experimental support for improvements in `for` comprehensions
*
* @see [[https://github.com/scala/improvement-proposals/pull/79]]
*/
@compileTimeOnly("`betterFors` can only be used at compile time in import statements")
object betterFors
end experimental

/** The deprecated object contains features that are no longer officially suypported in Scala.
Expand Down
3 changes: 2 additions & 1 deletion project/MiMaFilters.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ object MiMaFilters {
val ForwardsBreakingChanges: Map[String, Seq[ProblemFilter]] = Map(
// Additions that require a new minor version of the library
Build.mimaPreviousDottyVersion -> Seq(

ProblemFilters.exclude[MissingFieldProblem]("scala.runtime.stdLibPatches.language#experimental.betterFors"),
ProblemFilters.exclude[MissingClassProblem]("scala.runtime.stdLibPatches.language$experimental$betterFors$"),
),

// Additions since last LTS
Expand Down
12 changes: 12 additions & 0 deletions tests/run/better-fors.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
List((1,3), (1,4), (2,3), (2,4))
List((1,2,3), (1,2,4))
List((1,3), (1,4), (2,3), (2,4))
List((2,3), (2,4))
List((2,3), (2,4))
List((1,2), (2,4))
List(1, 2, 3)
List((2,3,6))
List(6)
List(3, 6)
List(6)
List(2)
105 changes: 105 additions & 0 deletions tests/run/better-fors.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import scala.language.experimental.betterFors

def for1 =
for {
a = 1
b <- List(a, 2)
c <- List(3, 4)
} yield (b, c)

def for2 =
for
a = 1
b = 2
c <- List(3, 4)
yield (a, b, c)

def for3 =
for {
a = 1
b <- List(a, 2)
c = 3
d <- List(c, 4)
} yield (b, d)

def for4 =
for {
a = 1
b <- List(a, 2)
if b > 1
c <- List(3, 4)
} yield (b, c)

def for5 =
for {
a = 1
b <- List(a, 2)
c = 3
if b > 1
d <- List(c, 4)
} yield (b, d)

def for6 =
for {
a = 1
b = 2
c <- for {
x <- List(a, b)
y = x * 2
} yield (x, y)
} yield c

def for7 =
for {
a <- List(1, 2, 3)
} yield a

def for8 =
for {
a <- List(1, 2)
b = a + 1
if b > 2
c = b * 2
if c < 8
} yield (a, b, c)

def for9 =
for {
a <- List(1, 2)
b = a * 2
if b > 2
} yield a + b

def for10 =
for {
a <- List(1, 2)
b = a * 2
} yield a + b

def for11 =
for {
a <- List(1, 2)
b = a * 2
if b > 2 && b % 2 == 0
} yield a + b

def for12 =
for {
a <- List(1, 2)
if a > 1
} yield a

object Test extends App {
println(for1)
println(for2)
println(for3)
println(for4)
println(for5)
println(for6)
println(for7)
println(for8)
println(for9)
println(for10)
println(for11)
println(for12)
}
3 changes: 3 additions & 0 deletions tests/run/fors.check
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ hello world
hello/1~2 hello/3~4 /1~2 /3~4 world/1~2 world/3~4
(2,1) (4,3)

testTailrec
List((4,Symbol(a)), (5,Symbol(b)), (6,Symbol(c)))

testGivens
123
456
Expand Down
Loading

0 comments on commit e261fa2

Please sign in to comment.