Skip to content

Commit

Permalink
Make the completion logic easier to extend
Browse files Browse the repository at this point in the history
  • Loading branch information
alexarchambault committed Mar 18, 2021
1 parent 2e05d9e commit c3f59a5
Show file tree
Hide file tree
Showing 3 changed files with 312 additions and 8 deletions.
16 changes: 8 additions & 8 deletions compiler/src/dotty/tools/dotc/interactive/Completion.scala
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ object Completion {
*
* Otherwise, provide no completion suggestion.
*/
private def completionMode(path: List[Tree], pos: SourcePosition): Mode =
def completionMode(path: List[Tree], pos: SourcePosition): Mode =
path match {
case (ref: RefTree) :: _ =>
if (ref.name.isTermName) Mode.Term
Expand All @@ -81,7 +81,7 @@ object Completion {
* Inspect `path` to determine the completion prefix. Only symbols whose name start with the
* returned prefix should be considered.
*/
private def completionPrefix(path: List[untpd.Tree], pos: SourcePosition): String =
def completionPrefix(path: List[untpd.Tree], pos: SourcePosition): String =
path match {
case (sel: untpd.ImportSelector) :: _ =>
completionPrefix(sel.imported :: Nil, pos)
Expand All @@ -100,7 +100,7 @@ object Completion {
}

/** Inspect `path` to determine the offset where the completion result should be inserted. */
private def completionOffset(path: List[Tree]): Int =
def completionOffset(path: List[Tree]): Int =
path match {
case (ref: RefTree) :: _ => ref.span.point
case _ => 0
Expand Down Expand Up @@ -134,7 +134,7 @@ object Completion {
* If several denotations share the same name, the type denotations appear before term denotations inside
* the same `Completion`.
*/
private def describeCompletions(completions: CompletionMap)(using Context): List[Completion] = {
def describeCompletions(completions: CompletionMap)(using Context): List[Completion] = {
completions
.toList.groupBy(_._1.toTermName) // don't distinguish between names of terms and types
.toList.map { (name, namedDenots) =>
Expand All @@ -153,7 +153,7 @@ object Completion {
*
* When there are multiple denotations, show their kinds.
*/
private def description(denots: List[SingleDenotation])(using Context): String =
def description(denots: List[SingleDenotation])(using Context): String =
denots match {
case denot :: Nil =>
if (denot.isType) denot.symbol.showFullName
Expand All @@ -174,7 +174,7 @@ object Completion {
* For the results of all `xyzCompletions` methods term names and type names are always treated as different keys in the same map
* and they never conflict with each other.
*/
private class Completer(val mode: Mode, val prefix: String, pos: SourcePosition) {
class Completer(val mode: Mode, val prefix: String, pos: SourcePosition) {
/** Completions for terms and types that are currently in scope:
* the members of the current class, local definitions and the symbols that have been imported,
* recursively adding completions from outer scopes.
Expand Down Expand Up @@ -442,11 +442,11 @@ object Completion {
* The completion mode: defines what kinds of symbols should be included in the completion
* results.
*/
private class Mode(val bits: Int) extends AnyVal {
class Mode(val bits: Int) extends AnyVal {
def is(other: Mode): Boolean = (bits & other.bits) == other.bits
def |(other: Mode): Mode = new Mode(bits | other.bits)
}
private object Mode {
object Mode {
/** No symbol should be included */
val None: Mode = new Mode(0)

Expand Down
133 changes: 133 additions & 0 deletions compiler/test/dotty/tools/dotc/interactive/CustomCompletion.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package dotty.tools.dotc.interactive

import dotty.tools.dotc.ast.tpd._
import dotty.tools.dotc.ast.untpd
import dotty.tools.dotc.core.Contexts._
import dotty.tools.dotc.core.Denotations.SingleDenotation
import dotty.tools.dotc.core.Flags._
import dotty.tools.dotc.core.NameOps._
import dotty.tools.dotc.core.Names.{Name, termName}
import dotty.tools.dotc.core.StdNames.nme
import dotty.tools.dotc.core.Symbols.{Symbol, defn}
import dotty.tools.dotc.core.TypeError
import dotty.tools.dotc.util.Chars.{isOperatorPart, isScalaLetter}
import dotty.tools.dotc.util.SourcePosition

object CustomCompletion {

def completions(
pos: SourcePosition,
dependencyCompleteOpt: Option[String => (Int, Seq[String])],
enableDeep: Boolean
)(using Context): (Int, List[Completion]) = {
val path = Interactive.pathTo(ctx.compilationUnit.tpdTree, pos.span)
computeCompletions(pos, path, dependencyCompleteOpt, enableDeep)(using Interactive.contextOfPath(path))
}

def computeCompletions(
pos: SourcePosition,
path: List[Tree],
dependencyCompleteOpt: Option[String => (Int, Seq[String])],
enableDeep: Boolean
)(using Context): (Int, List[Completion]) = {
val mode = Completion.completionMode(path, pos)
val prefix = Completion.completionPrefix(path, pos)
val completer = new DeepCompleter(mode, prefix, pos)

var extra = List.empty[Completion]

val completions = path match {
case Select(qual, _) :: _ => completer.selectionCompletions(qual)
case Import(Ident(name), _) :: _ if name.decode.toString == "$ivy" && dependencyCompleteOpt.nonEmpty =>
val complete = dependencyCompleteOpt.get
val (pos, completions) = complete(prefix)
val input0 = prefix.take(pos)
extra ++= completions.distinct.toList
.map(s => Completion(label(termName(input0 + s)), "", Nil))
Map.empty
case Import(expr, _) :: _ => completer.directMemberCompletions(expr)
case (_: untpd.ImportSelector) :: Import(expr, _) :: _ => completer.directMemberCompletions(expr)
case _ =>
completer.scopeCompletions ++ {
if (enableDeep) completer.deepCompletions
else Nil
}
}

val describedCompletions = extra ++ describeCompletions(completions)
val offset = Completion.completionOffset(path)

(pos.span.start - prefix.length, describedCompletions)
}

private type CompletionMap = Map[Name, Seq[SingleDenotation]]

private def describeCompletions(completions: CompletionMap)(using Context): List[Completion] = {
completions
.toList.groupBy(_._1.toTermName) // don't distinguish between names of terms and types
.toList.map { (name, namedDenots) =>
val denots = namedDenots.flatMap(_._2)
val typesFirst = denots.sortWith((d1, d2) => d1.isType && !d2.isType)
val desc = Completion.description(typesFirst)
Completion(label(name), desc, typesFirst.map(_.symbol))
}
}

class DeepCompleter(mode: Completion.Mode, prefix: String, pos: SourcePosition) extends Completion.Completer(mode, prefix, pos):
def deepCompletions(using Context): Map[Name, Seq[SingleDenotation]] = {

def allMembers(s: Symbol) =
try s.info.allMembers
catch {
case _: dotty.tools.dotc.core.TypeError => Nil
}
def rec(t: Symbol): Seq[Symbol] = {
val children =
if (t.is(Package) || t.is(PackageVal) || t.is(PackageClass)) {
allMembers(t).map(_.symbol).filter(_ != t).flatMap(rec)
} else Nil

t +: children.toSeq
}

val syms = for {
member <- allMembers(defn.RootClass).map(_.symbol).toList
sym <- rec(member)
if sym.name.toString.startsWith(prefix)
} yield sym

syms.map(sym => (sym.fullName, List(sym: SingleDenotation))).toMap
}

private val bslash = '\\'
private val specialChars = Set('[', ']', '(', ')', '{', '}', '.', ',', ';')

def label(name: Name): String = {

def maybeQuote(name: Name, recurse: Boolean): String =
if (recurse && name.isTermName)
name.asTermName.qualToString(maybeQuote(_, true), maybeQuote(_, false))
else {
// initially adapted from
// https://github.com/scala/scala/blob/decbd53f1bde4600c8ff860f30a79f028a8e431d/src/reflect/scala/reflect/internal/Printers.scala#L573-L584
val decName = name.decode.toString
val hasSpecialChar = decName.exists { ch =>
specialChars(ch) || ch.isWhitespace
}
def isOperatorLike = (name.isOperatorName || decName.exists(isOperatorPart)) &&
decName.exists(isScalaLetter) &&
!decName.contains(bslash)
lazy val term = name.toTermName

val needsBackTicks = hasSpecialChar ||
isOperatorLike ||
nme.keywords(term) && term != nme.USCOREkw

if (needsBackTicks) s"`$decName`"
else decName
}

maybeQuote(name, true)
}
}

171 changes: 171 additions & 0 deletions compiler/test/dotty/tools/dotc/interactive/CustomCompletionTests.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
package dotty.tools
package dotc.interactive

import dotc.ast.tpd
import dotc.{CompilationUnit, Compiler, Run}
import dotc.core.Contexts.Context
import dotc.core.Mode
import dotc.reporting.StoreReporter
import dotc.util.{SourceFile, SourcePosition}
import dotc.util.Spans.Span

import org.junit.Assert._
import org.junit.Test

class CustomCompletionTests extends DottyTest:

private def completions(
input: String,
dependencyCompleter: Option[String => (Int, Seq[String])] = None,
deep: Boolean = false,
extraDefinitions: String = ""
): (Int, Seq[Completion]) =
val prefix = extraDefinitions + """
object Wrapper {
val expr = {
"""
val suffix = """
}
}
"""

val allCode = prefix + input + suffix
val index = prefix.length + input.length

val run = new Run(
new Compiler,
initialCtx.fresh
.addMode(Mode.ReadPositions | Mode.Interactive)
// discard errors - comment out this line to print them in the console
.setReporter(new StoreReporter(null))
.setSetting(initialCtx.settings.YstopAfter, List("typer"))
)
val file = SourceFile.virtual("<completions>", allCode, maybeIncomplete = true)
given ctx: Context = run.runContext.withSource(file)
val unit = CompilationUnit(file)
ctx
.run
.compileUnits(unit :: Nil, ctx)

// ignoring compilation errors here - the input code
// to complete likely doesn't compile

unit.tpdTree = {
import tpd._
unit.tpdTree match {
case PackageDef(_, p) =>
p.reverseIterator.collectFirst {
case TypeDef(_, tmpl: Template) =>
tmpl.body
.collectFirst { case dd: ValDef if dd.name.show == "expr" => dd }
.getOrElse(sys.error("Unexpected tree shape"))
}
.getOrElse(sys.error("Unexpected tree shape"))
case _ => sys.error("Unexpected tree shape")
}
}
val ctx1 = ctx.fresh.setCompilationUnit(unit)
val srcPos = SourcePosition(file, Span(index))
val (offset0, completions) =
if (deep || dependencyCompleter.nonEmpty)
CustomCompletion.completions(srcPos, dependencyCompleteOpt = dependencyCompleter, enableDeep = deep)(using ctx1)
else
Completion.completions(srcPos)(using ctx1)
val offset = offset0 - prefix.length
(offset, completions)


@Test def simple(): Unit =
val prefix = "scala.collection.immutable."
val input = prefix + "Ma"

val (offset, completions0) = completions(input)
val labels = completions0.map(_.label)

assert(offset == prefix.length)
assert(labels.contains("Map"))

@Test def custom(): Unit =
val prefix = "import $ivy."
val input = prefix + "scala"

val dependencies = Seq(
"scalaCompiler",
"scalaLibrary",
"other"
)
val (offset, completions0) = completions(
input,
dependencyCompleter = Some { dep =>
val matches = dependencies.filter(_.startsWith(dep))
(0, matches)
}
)
val labels = completions0.map(_.label)

assert(offset == prefix.length)
assert(labels.contains("scalaCompiler"))
assert(labels.contains("scalaLibrary"))
assert(labels.length == 2)

@Test def backTicks(): Unit =
val prefix = "Foo."
val input = prefix + "a"

val extraDefinitions =
"""object Foo { def a1 = 2; def `a-b` = 3 }
|""".stripMargin
val (offset, completions0) = completions(
input,
extraDefinitions = extraDefinitions,
deep = true // Enables CustomCompleter
)
val labels = completions0.map(_.label)

assert(offset == prefix.length)
assert(labels.contains("a1"))
assert(labels.contains("`a-b`"))

@Test def backTicksDependencies(): Unit =
val prefix = "import $ivy."
val input = prefix + "`org.scala-lang:scala-`"

val dependencies = Seq(
"org.scala-lang:scala-compiler",
"org.scala-lang:scala-library",
"other"
)
val (offset, completions0) = completions(
input,
dependencyCompleter = Some { dep =>
val matches = dependencies.filter(_.startsWith(dep))
(0, matches)
}
)
val labels = completions0.map(_.label)

// Seems backticks mess with that for now...
// assert(offset == prefix.length)
assert(labels.contains("`org.scala-lang:scala-compiler`"))
assert(labels.contains("`org.scala-lang:scala-library`"))
assert(labels.length == 2)

@Test def deep(): Unit =
val prefix = ""
val input = prefix + "ListBuf"

val (offset, completions0) = completions(input, deep = true)
val labels = completions0.map(_.label)

assert(offset == prefix.length)
assert(labels.contains("scala.collection.mutable.ListBuffer"))

@Test def deepType(): Unit =
val prefix = ""
val input = prefix + "Function2"

val (offset, completions0) = completions(input, deep = true)
val labels = completions0.map(_.label)

assert(offset == prefix.length)
assert(labels.contains("scala.Function2"))

0 comments on commit c3f59a5

Please sign in to comment.