Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Metals to run DAP for tests #6452

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,7 @@ lazy val metals = project
V.lsp4j,
// for DAP
V.dap4j,
"ch.epfl.scala" %% "scala-debug-adapter" % V.debugAdapter,
// for finding paths of global log/cache directories
"dev.dirs" % "directories" % "26",
// for Java formatting
Expand Down Expand Up @@ -733,7 +734,6 @@ lazy val metalsDependencies = project
"ch.epfl.scala" % "bloop-maven-plugin" % V.mavenBloop,
"ch.epfl.scala" %% "gradle-bloop" % V.gradleBloop,
"com.sourcegraph" % "semanticdb-java" % V.javaSemanticdb,
"ch.epfl.scala" %% "scala-debug-adapter" % V.debugAdapter intransitive (),
"org.foundweekends.giter8" %% "giter8" % V.gitter8Version intransitive (),
),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,11 @@ final class ImplementationProvider(
classOwner = classOwnerInfoOpt.map(_.symbol),
alternativeSymbols = alternativeSymbols.toList,
overriddenSymbols = info.overriddenSymbols.toList,
properties = if (info.isAbstract) List(PcSymbolProperty.ABSTRACT) else Nil,
properties =
if (info.isAbstract) List(PcSymbolProperty.ABSTRACT) else Nil,
recursiveParents = parents,
annotations = info.annotations.map(_.toString()).toList,
memberDefsAnnotations = Nil,
)
}
}
Expand Down
12 changes: 12 additions & 0 deletions metals/src/main/scala/scala/meta/internal/metals/Compilers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -907,6 +907,18 @@ class Compilers(
.getOrElse(Future(None))
}

def info(
id: BuildTargetIdentifier,
symbol: String,
): Future[Option[PcSymbolInformation]] = {
loadCompiler(id)
.map(
_.info(symbol).asScala
.map(_.asScala.map(PcSymbolInformation.from))
)
.getOrElse(Future(None))
}

private def definition(
params: TextDocumentPositionParams,
token: CancelToken,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import scala.meta.internal.mtags.ResolvedOverriddenSymbol
import scala.meta.internal.mtags.UnresolvedOverriddenSymbol
import scala.meta.io.AbsolutePath

import org.h2.jdbc.JdbcBatchUpdateException

/**
* Handles caching of Jar Top Level Symbols in H2
*
Expand Down Expand Up @@ -177,6 +179,10 @@ final class JarTopLevels(conn: () => Connection) {
}
// Return number of rows inserted
symbolStmt.executeBatch().sum
} catch {
case e: JdbcBatchUpdateException =>
scribe.warn(e)
0
} finally {
if (symbolStmt != null) symbolStmt.close()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,9 @@ object MetalsEnrichments
value.replace('/', ':')
else value
}

def symbolToFullQualifiedName: String =
value.replaceAll("/|#", ".").stripSuffix(".")
}

implicit class XtensionTextDocumentSemanticdb(textDocument: s.TextDocument) {
Expand Down Expand Up @@ -1291,6 +1294,30 @@ object MetalsEnrichments
)
}

implicit class XtensionDebugSessionParams(params: b.DebugSessionParams) {
def asScalaMainClass(): Option[b.ScalaMainClass] =
params.getDataKind() match {
case b.DebugSessionParamsDataKind.SCALA_MAIN_CLASS =>
decodeJson(params.getData(), classOf[b.ScalaMainClass])
case _ => None
}

def asScalaTestSuites(): Option[b.ScalaTestSuites] =
params.getDataKind() match {
case b.TestParamsDataKind.SCALA_TEST_SUITES_SELECTION =>
decodeJson(params.getData(), classOf[b.ScalaTestSuites])
case b.TestParamsDataKind.SCALA_TEST_SUITES =>
for (
tests <- decodeJson(params.getData(), classOf[util.List[String]])
)
yield {
val suites =
tests.map(new b.ScalaTestSuiteSelection(_, Nil.asJava))
new b.ScalaTestSuites(suites, Nil.asJava, Nil.asJava)
}
}
}

/**
* Strips ANSI colors.
* As long as the color codes are valid this should correctly strip
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,17 +108,15 @@ final class RunTestCodeLens(
buildServerCanDebug,
isJVM,
)
} else if (buildServerCanDebug || clientConfig.isRunProvider()) {
} else
codeLenses(
textDocument,
buildTargetId,
classes,
distance,
path,
buildServerCanDebug,
isJVM,
)
} else { Nil }

}

Expand Down Expand Up @@ -165,7 +163,6 @@ final class RunTestCodeLens(
occurence: SymbolOccurrence,
textDocument: TextDocument,
target: BuildTargetIdentifier,
buildServerCanDebug: Boolean,
): Seq[l.Command] = {
if (occurence.symbol.endsWith("#main().")) {
textDocument.symbols
Expand All @@ -182,7 +179,6 @@ final class RunTestCodeLens(
Nil.asJava,
Nil.asJava,
),
buildServerCanDebug,
isJVM = true,
)
else
Expand All @@ -200,7 +196,6 @@ final class RunTestCodeLens(
classes: BuildTargetClasses.Classes,
distance: TokenEditDistance,
path: AbsolutePath,
buildServerCanDebug: Boolean,
isJVM: Boolean,
): Seq[l.CodeLens] = {
for {
Expand All @@ -210,24 +205,24 @@ final class RunTestCodeLens(
commands = {
val main = classes.mainClasses
.get(symbol)
.map(mainCommand(target, _, buildServerCanDebug, isJVM))
.map(mainCommand(target, _, isJVM))
.getOrElse(Nil)
val tests =
lazy val tests =
// Currently tests can only be run via DAP
if (clientConfig.isDebuggingProvider() && buildServerCanDebug)
if (clientConfig.isDebuggingProvider())
testClasses(target, classes, symbol, isJVM)
else Nil
val fromAnnot = DebugProvider
.mainFromAnnotation(occurrence, textDocument)
.flatMap { symbol =>
classes.mainClasses
.get(symbol)
.map(mainCommand(target, _, buildServerCanDebug, isJVM))
.map(mainCommand(target, _, isJVM))
}
.getOrElse(Nil)
val javaMains =
if (path.isJava)
javaLenses(occurrence, textDocument, target, buildServerCanDebug)
javaLenses(occurrence, textDocument, target)
else Nil
main ++ tests ++ fromAnnot ++ javaMains
}
Expand Down Expand Up @@ -260,15 +255,15 @@ final class RunTestCodeLens(
val main =
classes.mainClasses
.get(expectedMainClass)
.map(mainCommand(target, _, buildServerCanDebug, isJVM))
.map(mainCommand(target, _, isJVM))
.getOrElse(Nil)

val fromAnnotations = textDocument.occurrences.flatMap { occ =>
for {
sym <- DebugProvider.mainFromAnnotation(occ, textDocument)
cls <- classes.mainClasses.get(sym)
range <- occurrenceRange(occ, distance)
} yield mainCommand(target, cls, buildServerCanDebug, isJVM).map { cmd =>
} yield mainCommand(target, cls, isJVM).map { cmd =>
new l.CodeLens(range, cmd, null)
}
}.flatten
Expand Down Expand Up @@ -325,7 +320,6 @@ final class RunTestCodeLens(
private def mainCommand(
target: b.BuildTargetIdentifier,
main: b.ScalaMainClass,
buildServerCanDebug: Boolean,
isJVM: Boolean,
): List[l.Command] = {
val javaBinary = buildTargets
Expand Down Expand Up @@ -353,7 +347,7 @@ final class RunTestCodeLens(
sessionParams(target, dataKind, data)
}

if (clientConfig.isDebuggingProvider() && buildServerCanDebug && isJVM)
if (clientConfig.isDebuggingProvider() && isJVM)
List(
command("run", StartRunSession, params),
command("debug", StartDebugSession, params),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import scala.meta.internal.metals.BatchedFunction
import scala.meta.internal.metals.BuildTargets
import scala.meta.internal.metals.MetalsEnrichments._
import scala.meta.internal.metals.debug.BuildTargetClasses.Classes
import scala.meta.internal.metals.debug.BuildTargetClasses.TestSymbolInfo
import scala.meta.internal.semanticdb.Scala.Descriptor
import scala.meta.internal.semanticdb.Scala.Symbols

Expand All @@ -16,9 +17,9 @@ import ch.epfl.scala.{bsp4j => b}
/**
* In-memory index of main class symbols grouped by their enclosing build target
*/
final class BuildTargetClasses(
buildTargets: BuildTargets
)(implicit val ec: ExecutionContext) {
final class BuildTargetClasses(buildTargets: BuildTargets)(implicit
val ec: ExecutionContext
) {
private val index = TrieMap.empty[b.BuildTargetIdentifier, Classes]
private val jvmRunEnvironments
: TrieMap[b.BuildTargetIdentifier, b.JvmEnvironmentItem] =
Expand Down Expand Up @@ -56,6 +57,19 @@ final class BuildTargetClasses(
.map(_.fullyQualifiedName)
)

def getTestClasses(
name: String,
id: b.BuildTargetIdentifier,
): List[(String, TestSymbolInfo)] = {
index.get(id).toList.flatMap {
_.testClasses
.filter { case (_, info) =>
info.fullyQualifiedName == name
}
.toList
}
}

private def findClassesBy[A](
f: Classes => Option[A]
): List[(A, b.BuildTargetIdentifier)] = {
Expand All @@ -77,25 +91,22 @@ final class BuildTargetClasses(
Future.successful(())
case (Some(connection), targets0) =>
val targetsList = targets0.asJava
targetsList.forEach(invalidate)
val classes = targets0.map(t => (t, new Classes)).toMap

val updateMainClasses = connection
.mainClasses(new b.ScalaMainClassesParams(targetsList))
.map(cacheMainClasses(classes, _))

// Currently tests are only run using DAP
val updateTestClasses =
if (connection.isDebuggingProvider || connection.isSbt)
connection
.testClasses(new b.ScalaTestClassesParams(targetsList))
.map(cacheTestClasses(classes, _))
else Future.unit
connection
.testClasses(new b.ScalaTestClassesParams(targetsList))
.map(cacheTestClasses(classes, _))

for {
_ <- updateMainClasses
_ <- updateTestClasses
} yield {
targetsList.forEach(invalidate)
classes.foreach { case (id, classes) =>
index.put(id, classes)
}
Expand Down Expand Up @@ -214,7 +225,10 @@ final class BuildTargetClasses(
}
}

sealed abstract class TestFramework(val canResolveChildren: Boolean)
sealed abstract class TestFramework(val canResolveChildren: Boolean) {
def names: List[String]
}

object TestFramework {
def apply(framework: Option[String]): TestFramework = framework
.map {
Expand All @@ -226,11 +240,30 @@ object TestFramework {
}
.getOrElse(Unknown)
}
case object JUnit4 extends TestFramework(true)
case object MUnit extends TestFramework(true)
case object Scalatest extends TestFramework(true)
case object WeaverCatsEffect extends TestFramework(true)
case object Unknown extends TestFramework(false)

case object JUnit4 extends TestFramework(true) {
def names: List[String] = List("com.novocode.junit.JUnitFramework")
}

case object MUnit extends TestFramework(true) {
def names: List[String] = List("munit.Framework")
}

case object Scalatest extends TestFramework(true) {
def names: List[String] =
List(
"org.scalatest.tools.Framework",
"org.scalatest.tools.ScalaTestFramework",
)
}

case object WeaverCatsEffect extends TestFramework(true) {
def names: List[String] = Nil // TODO: find what classes should be here
}

case object Unknown extends TestFramework(false) {
def names: List[String] = Nil
}

object BuildTargetClasses {
type Symbol = String
Expand Down
Loading
Loading