From dd63d5d4167e60088df13f020add31110724d9a3 Mon Sep 17 00:00:00 2001 From: Albert Meltzer <7529386+kitbellew@users.noreply.github.com> Date: Tue, 31 May 2022 07:06:12 -0700 Subject: [PATCH] AvoidInfix: fix isWrapped, look for external paren --- .../org/scalafmt/rewrite/AvoidInfix.scala | 22 ++++++++++++++----- .../scala/org/scalafmt/rewrite/Rewrite.scala | 4 ++-- .../org/scalafmt/util/TokenTraverser.scala | 16 +++++++++++--- .../test/resources/rewrite/AvoidInfix.stat | 6 ++--- 4 files changed, 35 insertions(+), 13 deletions(-) diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/AvoidInfix.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/AvoidInfix.scala index 20d0a61478..7adba3cfb3 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/AvoidInfix.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/AvoidInfix.scala @@ -45,7 +45,7 @@ class AvoidInfix(implicit ctx: RewriteCtx) extends RewriteSession { builder += TokenPatch.AddLeft(opHead, ".", keepTok = true) val opLast = op.tokens.last - val opNextOpt = ctx.tokenTraverser.nextNonTrivialToken(opLast) + val opNextOpt = nextNonTrivial(opLast) def moveOpenDelim(prev: Token, open: Token): Unit = { // move delimiter (before comment or newline) @@ -65,7 +65,7 @@ class AvoidInfix(implicit ctx: RewriteCtx) extends RewriteSession { rb <- ctx.getMatchingOpt(lb) rbNext <- { moveOpenDelim(opLast, lb) - ctx.tokenTraverser.nextNonTrivialToken(rb) + nextNonTrivial(rb) } } yield (rb, rbNext) // move the left paren if enclosed, else enclose @@ -129,9 +129,21 @@ class AvoidInfix(implicit ctx: RewriteCtx) extends RewriteSession { }) } - private def isWrapped(t: Tree): Boolean = t.tokens.head match { - case h: Token.LeftParen => ctx.getMatchingOpt(h).contains(t.tokens.last) - case _ => false + private def isWrapped(t: Tree): Boolean = { + val head = t.tokens.head + val last = t.tokens.last + isMatching(head, last) || + ctx.tokenTraverser.prevNonTrivialToken(head).exists { + isMatching(_, nextNonTrivial(last).orNull) + } } + @inline + private def isMatching(head: Token, last: => Token): Boolean = + head.is[Token.LeftParen] && ctx.isMatching(head, last) + + @inline + private def nextNonTrivial(token: Token): Option[Token] = + ctx.tokenTraverser.nextNonTrivialToken(token) + } diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/Rewrite.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/Rewrite.scala index bfba397ba6..b8bf1453fe 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/Rewrite.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/Rewrite.scala @@ -31,8 +31,8 @@ case class RewriteCtx( @inline def getMatchingOpt(a: Token): Option[Token] = matchingParens.get(TokenOps.hash(a)) - @inline def isMatching(a: Token, b: Token) = - getMatchingOpt(a).contains(b) + @inline def isMatching(a: Token, b: => Token) = + getMatchingOpt(a).exists(_ eq b) @inline def getIndex(token: Token) = tokenTraverser.getIndex(token) diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/util/TokenTraverser.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/util/TokenTraverser.scala index eaa8d67447..94dcfb445f 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/util/TokenTraverser.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/util/TokenTraverser.scala @@ -49,9 +49,6 @@ class TokenTraverser(tokens: Tokens, input: Input) { } } - def nextNonTrivialToken(token: Token): Option[Token] = - findAfter(token)(x => if (x.is[Trivia]) None else Some(true)) - def prevToken(token: Token): Token = { tok2idx.get(token) match { case Some(i) if i > 0 => tokens(i - 1) @@ -59,6 +56,12 @@ class TokenTraverser(tokens: Tokens, input: Input) { } } + def nextNonTrivialToken(token: Token): Option[Token] = + findAfter(token)(TokenTraverser.isTrivialPred) + + def prevNonTrivialToken(token: Token): Option[Token] = + findBefore(token)(TokenTraverser.isTrivialPred) + /** Find a token after the given one. The search stops when the predicate * returns Some value (or the end is reached). * @return @@ -119,3 +122,10 @@ class TokenTraverser(tokens: Tokens, input: Input) { } } + +object TokenTraverser { + + private def isTrivialPred(token: Token): Option[Boolean] = + if (token.is[Trivia]) None else Some(true) + +} diff --git a/scalafmt-tests/src/test/resources/rewrite/AvoidInfix.stat b/scalafmt-tests/src/test/resources/rewrite/AvoidInfix.stat index a027f26f8d..588748ee9a 100644 --- a/scalafmt-tests/src/test/resources/rewrite/AvoidInfix.stat +++ b/scalafmt-tests/src/test/resources/rewrite/AvoidInfix.stat @@ -399,14 +399,14 @@ rewrite.neverInfix.includeFilters = [ "[^*]+" ] === 1 foo (2 +: 3) / 4 >>> -1.foo(((2 +: 3))./(4)) +1.foo((2 +: 3)./(4)) <<< avoid double wrap 2 rewrite.neverInfix.includeFilters = [ "[^*]+" ] === 1 foo ((2) +: 3) / 4 >>> -1.foo((((2) +: 3))./(4)) +1.foo(((2) +: 3)./(4)) <<< rewrite wrapped placeholder foo baz (_) >>> -foo baz (_) +foo.baz(_)