diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatWriter.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatWriter.scala index 176f376e8b..6abfb66919 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatWriter.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatWriter.scala @@ -174,8 +174,9 @@ class FormatWriter(formatOps: FormatOps) { case b: Term.Block if TreeOps.getBlockSingleStat(b).exists { /* guard for statements requiring a wrapper block * "foo { x => y; z }" can't become "foo(x => y; z)" */ - case Term.Function(_, body) => - TreeOps.getTermSingleStat(body).isDefined + case f: Term.Function => + TreeOps.getTermSingleStat(f.body).isDefined && + !RedundantBraces.needParensAroundParams(f) case _ => true } => b.parent match { diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala index e9a546b2b4..29a0662f3b 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala @@ -13,6 +13,18 @@ import org.scalafmt.util.TreeOps._ object RedundantBraces extends Rewrite { override def create(implicit ctx: RewriteCtx): RewriteSession = new RedundantBraces + + def needParensAroundParams(f: Term.Function): Boolean = + /* either we have parens or no type; multiple params or + * no params guarantee parens, so we look for type and + * parens only for a single param */ + f.params match { + case List(param) if param.decltpe.nonEmpty => + val leftParen = f.tokens.find(_.is[Token.LeftParen]) + !leftParen.exists(_.start <= param.tokens.head.start) + case _ => false + } + } /** @@ -193,13 +205,14 @@ class RedundantBraces(implicit ctx: RewriteCtx) extends RewriteSession { } } - private def okToRemoveBlockWithinApply(b: Term.Block): Boolean = { + private def okToRemoveBlockWithinApply(b: Term.Block): Boolean = getSingleStatIfLineSpanOk(b).exists { - case Term.Function(_, _: Term.Block) => - !settings.methodBodies // else the inner block will be rewritten, too + case f: Term.Function => + // don't rewrite block if the inner block will be rewritten, too + !(f.body.is[Term.Block] && settings.methodBodies) && + !RedundantBraces.needParensAroundParams(f) case _ => true } - } /** Some blocks look redundant but aren't */ private def shouldRemoveSingleStatBlock(b: Term.Block): Boolean = diff --git a/scalafmt-tests/src/test/resources/rewrite/RedundantBraces-ParenLambdas.stat b/scalafmt-tests/src/test/resources/rewrite/RedundantBraces-ParenLambdas.stat index 0ca9d2aaf4..7e3d3587b6 100644 --- a/scalafmt-tests/src/test/resources/rewrite/RedundantBraces-ParenLambdas.stat +++ b/scalafmt-tests/src/test/resources/rewrite/RedundantBraces-ParenLambdas.stat @@ -136,20 +136,18 @@ override def run(args: List[String]): IO[ExitCode] = program } >>> -test does not parse override def run(args: List[String]): IO[ExitCode] = Slf4jLogger .create[IO] - .flatMap(implicit logger: Logger[IO] => program) + .flatMap { implicit logger: Logger[IO] => program } <<< #1707 2: don't rewrite to parens if typed lambda param def a(b: B): C[D] = C[String].contramap[D] { i: D => b.format(i) } >>> -test does not parse -def a(formatter: DateTimeFormatter): b[Instant] = - C[String].contramap[D](i: D => b.format(i)) +def a(b: B): C[D] = + C[String].contramap[D] { i: D => b.format(i) } <<< #1707 3: rewrite to parens if typed lambda param with parens def a(b: B): C[D] = C[String].contramap[D] { (i: D) => @@ -172,6 +170,4 @@ danglingParentheses = false val a = b[IO]( { c: Int => if (d) e else f }, g) >>> -test does not parse -val a = b[IO]( - c: Int => if (d) e else f , g) +val a = b[IO]({ c: Int => if (d) e else f }, g)