Skip to content

Commit

Permalink
feat: add framework field to the scala-test-classes bsp's endpoint (#…
Browse files Browse the repository at this point in the history
…1695)

Build client can use information about the test framework of tests to provide better UX when running/debugging tests.
More information can be found at build-server-protocol/build-server-protocol#296 (comment).
  • Loading branch information
kpodsiad authored Mar 14, 2022
1 parent 1b6633a commit 2580f21
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 63 deletions.
31 changes: 17 additions & 14 deletions frontend/src/main/scala/bloop/bsp/BloopBspServices.scala
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ final class BloopBspServices(
.requestAsync(endpoints.BuildTarget.run)(p => schedule(run(p)))
.requestAsync(endpoints.BuildTarget.cleanCache)(p => schedule(clean(p)))
.requestAsync(endpoints.BuildTarget.scalaMainClasses)(p => schedule(scalaMainClasses(p)))
.requestAsync(endpoints.BuildTarget.scalaTestClasses)(p => schedule(scalaTestClasses(p)))
.requestAsync(ScalaTestClasses.endpoint)(p => schedule(scalaTestClasses(p)))
.requestAsync(endpoints.BuildTarget.dependencySources)(p => schedule(dependencySources(p)))
.requestAsync(endpoints.DebugSession.start)(p => schedule(startDebugSession(p)))
.requestAsync(endpoints.BuildTarget.jvmTestEnvironment)(p => schedule(jvmTestEnvironment(p)))
Expand Down Expand Up @@ -541,25 +541,28 @@ final class BloopBspServices(

def scalaTestClasses(
params: bsp.ScalaTestClassesParams
): BspEndpointResponse[bsp.ScalaTestClassesResult] = {
): BspEndpointResponse[ScalaTestClassesResult] =
ifInitialized(params.originId) { (state: State, logger: BspServerLogger) =>
mapToProjects(params.targets, state) match {
case Left(error) =>
logger.error(error)
Task.now((state, Right(bsp.ScalaTestClassesResult(Nil))))
Task.now((state, Right(ScalaTestClassesResult(Nil))))

case Right(projects) =>
val subTasks = for {
(id, project) <- projects.toList
task = TestTask.findFullyQualifiedTestNames(project, state)
item = task.map(classes => bsp.ScalaTestClassesItem(id, classes))
} yield item

for {
items <- Task.sequence(subTasks)
result = new bsp.ScalaTestClassesResult(items)
} yield (state, Right(result))
}
val subTasks = projects.toList.map { case (id, project) =>
val task = TestTask.findTestNamesWithFramework(project, state)
val item = task.map { classes =>
classes.groupBy(_.framework).map { case (framework, classes) =>
ScalaTestClassesItem(id, classes.flatMap(_.classes), Some(framework))
}.toList
}
item
}

Task.sequence(subTasks).map { items =>
val result = ScalaTestClassesResult(items.flatten)
(state, Right(result))
}
}
}

Expand Down
35 changes: 35 additions & 0 deletions frontend/src/main/scala/bloop/bsp/ScalaTestSuitesResult.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package bloop.bsp

import io.circe.derivation.JsonCodec
import ch.epfl.scala.bsp.BuildTargetIdentifier
import io.circe.derivation.deriveDecoder
import io.circe.derivation.deriveEncoder
import io.circe.Decoder
import io.circe.Encoder
import scala.meta.jsonrpc.Endpoint
import ch.epfl.scala.bsp.ScalaTestClassesParams

object ScalaTestClasses {
val endpoint = new Endpoint[ScalaTestClassesParams, ScalaTestClassesResult]("buildTarget/scalaTestClasses")
}

final case class ScalaTestClassesResult(
items: List[ScalaTestClassesItem]
)

object ScalaTestClassesResult {
implicit val decoder: Decoder[ScalaTestClassesResult] = deriveDecoder
implicit val encoder: Encoder[ScalaTestClassesResult] = deriveEncoder
}

final case class ScalaTestClassesItem(
target: BuildTargetIdentifier,
// Fully qualified names of test classes
classes: List[String],
// Name of the sbt's test framework
framework: Option[String]
)
object ScalaTestClassesItem {
implicit val decoder: Decoder[ScalaTestClassesItem] = deriveDecoder
implicit val encoder: Encoder[ScalaTestClassesItem] = deriveEncoder
}
5 changes: 3 additions & 2 deletions frontend/src/main/scala/bloop/engine/Interpreter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -462,9 +462,10 @@ object Interpreter {
stateWithNoopLogger = state.copy(logger = NoopLogger)
project <- Tasks.pickTestProject(projectName, stateWithNoopLogger)
} yield {
TestTask.findFullyQualifiedTestNames(project, stateWithNoopLogger).map { testsFqcn =>
TestTask.findTestNamesWithFramework(project, stateWithNoopLogger).map { discovered =>
for {
testFqcn <- testsFqcn
classesWithFramework <- discovered
testFqcn <- classesWithFramework.classes
completion <- cmd.format.showTestName(testFqcn)
} state.logger.info(completion)
state
Expand Down
35 changes: 19 additions & 16 deletions frontend/src/main/scala/bloop/engine/tasks/TestTask.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import bloop.engine.tasks.toolchains.ScalaJsToolchain
import bloop.exec.{Forker, JvmProcessForker}
import bloop.io.AbsolutePath
import bloop.logging.{DebugFilter, Logger}
import bloop.testing.{DiscoveredTestFrameworks, LoggingEventHandler, TestInternals}
import bloop.testing.{DiscoveredTestFrameworks, LoggingEventHandler, TestInternals, FingerprintInfo}
import bloop.util.JavaCompat.EnrichOptional
import monix.eval.Task
import monix.execution.atomic.AtomicBoolean
Expand All @@ -21,6 +21,11 @@ import scala.util.control.NonFatal
import scala.util.{Failure, Success}
import bloop.bsp.ScalaTestSuites

final case class TestFrameworkWithClasses(
framework: String,
classes: List[String]
)

object TestTask {
implicit private val logContext: DebugFilter = DebugFilter.Test

Expand Down Expand Up @@ -316,14 +321,14 @@ object TestTask {
val (subclassPrints, annotatedPrints) = TestInternals.getFingerprints(frameworks)
val definitions = TestInternals.potentialTests(analysis)
val discovered =
Discovery(subclassPrints.map(_._1).toSet, annotatedPrints.map(_._1).toSet)(definitions)
Discovery(subclassPrints.map(_.name).toSet, annotatedPrints.map(_.name).toSet)(definitions)
val tasks = mutable.Map.empty[Framework, mutable.Buffer[TaskDef]]
val seen = mutable.Set.empty[String]
frameworks.foreach(tasks(_) = mutable.Buffer.empty)
discovered.foreach {
case (defn, discovered) =>
TestInternals.matchingFingerprints(subclassPrints, annotatedPrints, discovered).foreach {
case (_, _, framework, fingerprint) =>
case FingerprintInfo(_, _, framework, fingerprint) =>
if (seen.add(defn.name)) {
tasks(framework) += new TaskDef(
defn.name,
Expand All @@ -344,28 +349,26 @@ object TestTask {
* @param project The project for which to find tests.
* @return An array containing all the testsFQCN that were detected.
*/
def findFullyQualifiedTestNames(
def findTestNamesWithFramework(
project: Project,
state: State
): Task[List[String]] = {
import state.logger
): Task[List[TestFrameworkWithClasses]] =
TestTask.discoverTestFrameworks(project, state).map {
case None => List.empty[String]
case None => List.empty
case Some(found) =>
val frameworks = found.frameworks
val lastCompileResult = state.results.lastSuccessfulResultOrEmpty(project)
val analysis = lastCompileResult.previous.analysis().toOption.getOrElse {
logger.debug(s"TestsFQCN was triggered, but no compilation detected for ${project.name}")(
DebugFilter.All
)
state.logger
.debug(s"TestsFQCN was triggered, but no compilation detected for ${project.name}")(
DebugFilter.All
)
Analysis.empty
}
val tests = discoverTests(analysis, frameworks)
tests.toList
.flatMap {
case (framework, tasks) => tasks.map(t => (framework, t))
}
.map(_._2.fullyQualifiedName)
tests.map {
case (framework, tasks) =>
TestFrameworkWithClasses(framework.name, tasks.map(_.fullyQualifiedName))
}.toList
}
}
}
31 changes: 18 additions & 13 deletions frontend/src/main/scala/bloop/testing/TestInternals.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ import xsbti.compile.CompileAnalysis
import scala.collection.mutable
import scala.util.control.NonFatal

final case class FingerprintInfo[+Print <: Fingerprint](
name: String,
isModule: Boolean,
framework: Framework,
fingerprint: Print
)

object TestInternals {
private final val sbtOrg = "org.scala-sbt"
private final val testAgentId = "test-agent"
Expand All @@ -42,8 +49,6 @@ object TestInternals {
// Cache the resolution of test agent files since it's static (cannot be lazy because depends on logger)
@volatile private var testAgentFiles: Option[Array[AbsolutePath]] = None

private type PrintInfo[F <: Fingerprint] = (String, Boolean, Framework, F)

lazy val filteredLoader = {
val filter = new IncludePackagesFilter(
Set(
Expand Down Expand Up @@ -248,29 +253,29 @@ object TestInternals {

def getFingerprints(
frameworks: Seq[Framework]
): (List[PrintInfo[SubclassFingerprint]], List[PrintInfo[AnnotatedFingerprint]]) = {
): (List[FingerprintInfo[SubclassFingerprint]], List[FingerprintInfo[AnnotatedFingerprint]]) = {
// The tests need to be run with the first matching framework, so we use a LinkedHashSet
// to keep the ordering of `frameworks`.
val subclasses = mutable.LinkedHashSet.empty[PrintInfo[SubclassFingerprint]]
val annotated = mutable.LinkedHashSet.empty[PrintInfo[AnnotatedFingerprint]]
val subclasses = mutable.LinkedHashSet.empty[FingerprintInfo[SubclassFingerprint]]
val annotated = mutable.LinkedHashSet.empty[FingerprintInfo[AnnotatedFingerprint]]
for {
framework <- frameworks
fingerprint <- framework.fingerprints()
} fingerprint match {
case sub: SubclassFingerprint =>
subclasses += ((sub.superclassName, sub.isModule, framework, sub))
subclasses += FingerprintInfo(sub.superclassName, sub.isModule, framework, sub)
case ann: AnnotatedFingerprint =>
annotated += ((ann.annotationName, ann.isModule, framework, ann))
annotated += FingerprintInfo(ann.annotationName, ann.isModule, framework, ann)
}
(subclasses.toList, annotated.toList)
}

// Slightly adapted from sbt/sbt
def matchingFingerprints(
subclassPrints: List[PrintInfo[SubclassFingerprint]],
annotatedPrints: List[PrintInfo[AnnotatedFingerprint]],
subclassPrints: List[FingerprintInfo[SubclassFingerprint]],
annotatedPrints: List[FingerprintInfo[AnnotatedFingerprint]],
d: Discovered
): List[PrintInfo[Fingerprint]] = {
): List[FingerprintInfo[Fingerprint]] = {
defined(subclassPrints, d.baseClasses, d.isModule) ++
defined(annotatedPrints, d.annotations, d.isModule)
}
Expand Down Expand Up @@ -316,11 +321,11 @@ object TestInternals {

// Slightly adapted from sbt/sbt
private def defined[T <: Fingerprint](
in: List[PrintInfo[T]],
in: List[FingerprintInfo[T]],
names: Set[String],
IsModule: Boolean
): List[PrintInfo[T]] = {
in collect { case info @ (name, IsModule, _, _) if names(name) => info }
): List[FingerprintInfo[T]] = {
in.collect { case info @ FingerprintInfo(name, IsModule, _, _) if names(name) => info }
}

private def loadFramework(loader: ClassLoader, fqn: String, logger: Logger): Option[Framework] = {
Expand Down
4 changes: 2 additions & 2 deletions frontend/src/test/scala/bloop/bsp/BspBaseSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -310,10 +310,10 @@ abstract class BspBaseSuite extends BaseSuite with BspClientTest {
TestUtil.await(FiniteDuration(5, "s"))(task)
}

def testClasses(project: TestProject): bsp.ScalaTestClassesResult = {
def testClasses(project: TestProject): ScalaTestClassesResult = {
val task = runAfterTargets(project) { target =>
val params = bsp.ScalaTestClassesParams(List(target), None)
endpoints.BuildTarget.scalaTestClasses.request(params).map {
ScalaTestClasses.endpoint.request(params).map {
case Left(error) => fail(s"Received error $error")
case Right(result) => result
}
Expand Down
33 changes: 17 additions & 16 deletions frontend/src/test/scala/bloop/bsp/BspProtocolSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import bloop.bsp.BloopBspDefinitions.BloopExtraBuildParams
import io.circe.Json
import bloop.testing.DiffAssertions.TestFailedException
import bloop.data.SourcesGlobs
import scala.collection.immutable

object TcpBspProtocolSpec extends BspProtocolSpec(BspProtocol.Tcp)
object LocalBspProtocolSpec extends BspProtocolSpec(BspProtocol.Local)
Expand Down Expand Up @@ -288,23 +289,22 @@ class BspProtocolSpec(
loadBspBuildFromResources("cross-test-build-scalajs-0.6", workspace, logger) { build =>
val project = build.projectFor("test-project-test")
val compiledState = build.state.compile(project, timeout = 120)
val expectedClasses = Set(
"JUnitTest",
"ScalaTestTest",
"ScalaCheckTest",
"WritingTest",
"Specs2Test",
"EternalUTest",
"UTestTest",
"ResourcesTest"
).map("hello." + _)

val testClasses = compiledState.testClasses(project)
val items = testClasses.items
assert(items.size == 1)
val expectedSuites = Set(
("JUnit", List("hello.JUnitTest")),
("ScalaCheck", List("hello.ScalaCheckTest")),
("specs2", List("hello.Specs2Test")),
("utest", List("hello.EternalUTest", "hello.UTestTest")),
("ScalaTest", List("hello.ScalaTestTest", "hello.WritingTest", "hello.ResourcesTest")),
).map { case (framework, classes) =>
ScalaTestClassesItem(project.bspId, classes, Some(framework))
}

val testSuites = compiledState.testClasses(project)
val items = testSuites.items

val classes = items.head.classes.toSet
try assertEquals(classes, expectedClasses)
assert(items.size == expectedSuites.size)

try assertEquals(items.toSet, expectedSuites)
catch { case t: TestFailedException => logger.dump(); throw t }
}
}
Expand Down Expand Up @@ -539,4 +539,5 @@ class BspProtocolSpec(
}
}
}

}

0 comments on commit 2580f21

Please sign in to comment.