Skip to content

Commit

Permalink
improvement: Match completions on union types
Browse files Browse the repository at this point in the history
Exhaustive match completion now works on union of case classes.
Also improves the labels for match completion.
  • Loading branch information
jkciesluk committed Feb 16, 2024
1 parent ad7e280 commit 659e1d7
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,5 @@ import dotty.tools.dotc.core.Contexts.Context
import dotty.tools.dotc.core.Symbols.Symbol

object MetalsSealedDesc:
// For scala 3.0.0 and 3.0.1 method `sealedStrictDescendants` is not available
def sealedStrictDescendants(sym: Symbol)(using Context): List[Symbol] =
Nil
// For scala 3.0.0 and 3.0.1 method `sealedDescendants` is not available
def sealedDescendants(sym: Symbol)(using Context): List[Symbol] = sym
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ import dotty.tools.dotc.core.StdNames.*
import dotty.tools.dotc.core.Symbols.Symbol

object MetalsSealedDesc:
def sealedStrictDescendants(sym: Symbol)(using Context): List[Symbol] =
sym.sealedStrictDescendants.filter(child =>
def sealedDescendants(sym: Symbol)(using Context): List[Symbol] =
sym.sealedDescendants.filter(child =>
!(child.is(Sealed) && (child.is(Abstract) || child.is(Trait)))
&& (child.isPublic || child.isAccessibleFrom(sym.info)) &&
child.name != tpnme.LOCAL_CHILD
&& child.maybeOwner.exists
&& (child.isPublic || child.isAccessibleFrom(sym.info))
&& child.name != tpnme.LOCAL_CHILD
)
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ import dotty.tools.dotc.core.StdNames.*
import dotty.tools.dotc.core.Symbols.Symbol

object MetalsSealedDesc:
def sealedStrictDescendants(sym: Symbol)(using Context): List[Symbol] =
sym.sealedStrictDescendants.filter(child =>
def sealedDescendants(sym: Symbol)(using Context): List[Symbol] =
sym.sealedDescendants.filter(child =>
!(child.is(Sealed) && (child.is(Abstract) || child.is(Trait)))
&& (child.isPublic || child.isAccessibleFrom(sym.info)) &&
child.name != tpnme.LOCAL_CHILD
&& child.maybeOwner.exists
&& (child.isPublic || child.isAccessibleFrom(sym.info))
&& child.name != tpnme.LOCAL_CHILD
)
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ import dotty.tools.dotc.core.StdNames.*
import dotty.tools.dotc.core.Symbols.Symbol

object MetalsSealedDesc:
def sealedStrictDescendants(sym: Symbol)(using Context): List[Symbol] =
sym.sealedStrictDescendants.filter(child =>
def sealedDescendants(sym: Symbol)(using Context): List[Symbol] =
sym.sealedDescendants.filter(child =>
!(child.is(Sealed) && (child.is(Abstract) || child.is(Trait)))
&& (child.isPublic || child.isAccessibleFrom(sym.info)) &&
child.name != tpnme.LOCAL_CHILD
&& child.maybeOwner.exists
&& (child.isPublic || child.isAccessibleFrom(sym.info))
&& child.name != tpnme.LOCAL_CHILD
)
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@ import scala.collection.JavaConverters.*
import scala.collection.mutable
import scala.collection.mutable.ListBuffer

import scala.meta.internal.metals.ReportContext
import scala.meta.internal.mtags.MtagsEnrichments.*
import scala.meta.internal.pc.AutoImports.AutoImportsGenerator
import scala.meta.internal.pc.AutoImports.SymbolImport
import scala.meta.internal.pc.MetalsInteractive.*
import scala.meta.internal.pc.printer.MetalsPrinter
import scala.meta.internal.pc.printer.MetalsPrinter.IncludeDefaultParam
import scala.meta.pc.PresentationCompilerConfig
import scala.meta.pc.SymbolSearch

Expand All @@ -23,7 +26,6 @@ import dotty.tools.dotc.core.Symbols.NoSymbol
import dotty.tools.dotc.core.Symbols.Symbol
import dotty.tools.dotc.core.Types.AndType
import dotty.tools.dotc.core.Types.ClassInfo
import dotty.tools.dotc.core.Types.NoType
import dotty.tools.dotc.core.Types.OrType
import dotty.tools.dotc.core.Types.Type
import dotty.tools.dotc.core.Types.TypeRef
Expand Down Expand Up @@ -57,7 +59,7 @@ object CaseKeywordCompletion:
patternOnly: Option[String] = None,
hasBind: Boolean = false,
includeExhaustive: Option[NewLineOptions] = None,
): List[CompletionValue] =
)(using ReportContext): List[CompletionValue] =
import indexedContext.ctx
val definitions = indexedContext.ctx.definitions
val clientSupportsSnippets = config.isCompletionSnippetsEnabled()
Expand All @@ -67,6 +69,11 @@ object CaseKeywordCompletion:
patternOnly,
hasBind,
)
val printer = MetalsPrinter.standard(
indexedContext,
search,
IncludeDefaultParam.Never,
)

val selTpe = selector match
case EmptyTree =>
Expand Down Expand Up @@ -140,14 +147,6 @@ object CaseKeywordCompletion:
result += symImport
end visit

// Step 0: case for selector type
selectorSym.info match
case NoType => ()
case _ =>
if !(selectorSym.is(Sealed) &&
(selectorSym.is(Abstract) || selectorSym.is(Trait)))
then visit((autoImportsGen.inferSymbolImport(selectorSym)))

// Step 1: walk through scope members.
def isValid(sym: Symbol) = !tpes(sym) &&
!isBottom(sym) &&
Expand Down Expand Up @@ -200,8 +199,7 @@ object CaseKeywordCompletion:
search,
)
sealedMembers match
case Nil => caseItems
case (_, label) :: tail =>
case (_, label) :: tail if tail.length > 0 =>
val (newLine, addIndent) =
if moveToNewLine then ("\n\t", "\t") else ("", "")
val insertText = Some(
Expand All @@ -222,9 +220,10 @@ object CaseKeywordCompletion:
s"case (exhaustive)",
insertText,
importEdit.toList,
s" ${selectorSym.decodedName} (${res.length} cases)",
s" ${printer.tpe(selTpe)} (${res.length} cases)",
)
exhaustive :: caseItems
case _ => caseItems
end match
case None => caseItems
end match
Expand All @@ -249,14 +248,22 @@ object CaseKeywordCompletion:
search: SymbolSearch,
autoImportsGen: AutoImportsGenerator,
noIndent: Boolean,
): List[CompletionValue] =
)(using ReportContext): List[CompletionValue] =
import indexedContext.ctx
val clientSupportsSnippets = config.isCompletionSnippetsEnabled()

val printer = MetalsPrinter.standard(
indexedContext,
search,
IncludeDefaultParam.Never,
)

val completionGenerator = CompletionValueGenerator(
completionPos,
clientSupportsSnippets,
)

val tpeStr = printer.tpe(selector.tpe.widen.metalsDealias.bounds.hi)
val tpe = selector.tpe.widen.metalsDealias.bounds.hi match
case tr @ TypeRef(_, _) => tr.underlying
case t => t
Expand Down Expand Up @@ -302,7 +309,7 @@ object CaseKeywordCompletion:
"match (exhaustive)",
insertText,
importEdit.toList,
s" ${tpe.typeSymbol.decodedName} (${labels.length} cases)",
s" ${tpeStr} (${labels.length} cases)",
)
List(basicMatch, exhaustive)
completions
Expand Down Expand Up @@ -351,19 +358,20 @@ object CaseKeywordCompletion:
* because `A <:< (B & C) == false`.
*/
def isExhaustiveMember(sym: Symbol): Boolean =
val symTpe = sym.info match
sym.info match
case cl: ClassInfo =>
cl.parents
val parentsMerged = cl.parents
.reduceLeftOption((tp1, tp2) => tp1.&(tp2))
.getOrElse(sym.info)
case simple => simple
symTpe <:< tpe

cl.selfType <:< tpe || parentsMerged <:< tpe
case simple => simple <:< tpe

val parents = getParentTypes(tpe, List.empty)
parents.toList.map { parent =>
// There is an issue in Dotty, `sealedStrictDescendants` ends in an exception for java enums. https://github.com/lampepfl/dotty/issues/15908
if parent.isAllOf(JavaEnumTrait) then parent.children
else MetalsSealedDesc.sealedStrictDescendants(parent)
else MetalsSealedDesc.sealedDescendants(parent)
} match
case Nil => Nil
case subcls :: Nil => subcls
Expand Down
24 changes: 20 additions & 4 deletions tests/cross/src/test/scala/tests/pc/CompletionCaseSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ class CompletionCaseSuite extends BaseCompletionSuite {
|""".stripMargin,
compat = Map("3" -> """|case None => scala
|case Some(value) => scala
|case (exhaustive) Option (2 cases)
|case (exhaustive) Option[Int] (2 cases)
|""".stripMargin)
)

Expand Down Expand Up @@ -338,7 +338,7 @@ class CompletionCaseSuite extends BaseCompletionSuite {
|""".stripMargin,
compat = Map("3" -> """|case None => scala
|case Some(value) => scala
|case (exhaustive) Option (2 cases)
|case (exhaustive) Option[Int] (2 cases)
|""".stripMargin)
)

Expand All @@ -356,7 +356,7 @@ class CompletionCaseSuite extends BaseCompletionSuite {
|""".stripMargin,
compat = Map("3" -> """|case None => scala
|case Some(value) => scala
|case (exhaustive) Option (2 cases)
|case (exhaustive) Option[Int] (2 cases)
|""".stripMargin)
)

Expand Down Expand Up @@ -581,7 +581,7 @@ class CompletionCaseSuite extends BaseCompletionSuite {
)

check(
"private-member".tag(IgnoreScala2.and(IgnoreForScala3CompilerPC)),
"private-member1".tag(IgnoreScala2.and(IgnoreForScala3CompilerPC)),
"""
|package example
|import scala.collection.immutable.Vector
Expand Down Expand Up @@ -793,4 +793,20 @@ class CompletionCaseSuite extends BaseCompletionSuite {
|""".stripMargin
)

check(
"union-type".tag(IgnoreScala2.and(IgnoreForScala3CompilerPC)),
"""
|case class Foo(a: Int)
|case class Bar(b: Int)
|
|object O {
| val x: Foo | Bar = ???
| val y = List(x).map{ ca@@ }
|}""".stripMargin,
"""|case Bar(b) => union-type
|case Foo(a) => union-type
|case (exhaustive) Foo | Bar (2 cases)
|""".stripMargin
)

}
35 changes: 19 additions & 16 deletions tests/cross/src/test/scala/tests/pc/CompletionMatchSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,7 @@ class CompletionMatchSuite extends BaseCompletionSuite {
|}""".stripMargin,
"""|match
|match (exhaustive) Option[Int] (2 cases)
|""".stripMargin,
compat = Map(
"3" -> """|match
|match (exhaustive) Option (2 cases)
|""".stripMargin
)
|""".stripMargin
)

check(
Expand All @@ -34,12 +29,7 @@ class CompletionMatchSuite extends BaseCompletionSuite {
|}""".stripMargin,
"""|match
|match (exhaustive) Option[Int] (2 cases)
|""".stripMargin,
compat = Map(
"3" -> """|match
|match (exhaustive) Option (2 cases)
|""".stripMargin
)
|""".stripMargin
)

// In Scala3 it's allowed to write xxx.match
Expand All @@ -52,7 +42,7 @@ class CompletionMatchSuite extends BaseCompletionSuite {
"",
compat = Map(
"3" -> """|match
|match (exhaustive) Option (2 cases)
|match (exhaustive) Option[Int] (2 cases)
|""".stripMargin
)
)
Expand Down Expand Up @@ -475,9 +465,7 @@ class CompletionMatchSuite extends BaseCompletionSuite {
"""|case (exhaustive) Option[A] (2 cases)
|""".stripMargin,
compat = Map(
"3" ->
"""|case (exhaustive) Option (2 cases)
|""".stripMargin
"3" -> "case (exhaustive) Option[Int] (2 cases)"
),
filter = _.contains("exhaustive")
)
Expand Down Expand Up @@ -853,4 +841,19 @@ class CompletionMatchSuite extends BaseCompletionSuite {
filter = _.contains("exhaustive")
)

check(
"union-type".tag(IgnoreScala2.and(IgnoreForScala3CompilerPC)),
"""
|case class Foo(a: Int)
|case class Bar(b: Int)
|
|object O {
| val x: Foo | Bar = ???
| val y = x match@@
|}""".stripMargin,
"""|match
|match (exhaustive) Foo | Bar (2 cases)
|""".stripMargin
)

}

0 comments on commit 659e1d7

Please sign in to comment.