diff --git a/frontend/src/main/scala/bloop/bsp/BloopBspServices.scala b/frontend/src/main/scala/bloop/bsp/BloopBspServices.scala index 1b39e678ff..5437c52c7b 100644 --- a/frontend/src/main/scala/bloop/bsp/BloopBspServices.scala +++ b/frontend/src/main/scala/bloop/bsp/BloopBspServices.scala @@ -109,7 +109,7 @@ final class BloopBspServices( .requestAsync(endpoints.BuildTarget.run)(p => schedule(run(p))) .requestAsync(endpoints.BuildTarget.cleanCache)(p => schedule(clean(p))) .requestAsync(endpoints.BuildTarget.scalaMainClasses)(p => schedule(scalaMainClasses(p))) - .requestAsync(endpoints.BuildTarget.scalaTestClasses)(p => schedule(scalaTestClasses(p))) + .requestAsync(ScalaTestClasses.endpoint)(p => schedule(scalaTestClasses(p))) .requestAsync(endpoints.BuildTarget.dependencySources)(p => schedule(dependencySources(p))) .requestAsync(endpoints.DebugSession.start)(p => schedule(startDebugSession(p))) .requestAsync(endpoints.BuildTarget.jvmTestEnvironment)(p => schedule(jvmTestEnvironment(p))) @@ -541,25 +541,28 @@ final class BloopBspServices( def scalaTestClasses( params: bsp.ScalaTestClassesParams - ): BspEndpointResponse[bsp.ScalaTestClassesResult] = { + ): BspEndpointResponse[ScalaTestClassesResult] = ifInitialized(params.originId) { (state: State, logger: BspServerLogger) => mapToProjects(params.targets, state) match { case Left(error) => logger.error(error) - Task.now((state, Right(bsp.ScalaTestClassesResult(Nil)))) + Task.now((state, Right(ScalaTestClassesResult(Nil)))) case Right(projects) => - val subTasks = for { - (id, project) <- projects.toList - task = TestTask.findFullyQualifiedTestNames(project, state) - item = task.map(classes => bsp.ScalaTestClassesItem(id, classes)) - } yield item - - for { - items <- Task.sequence(subTasks) - result = new bsp.ScalaTestClassesResult(items) - } yield (state, Right(result)) - } + val subTasks = projects.toList.map { case (id, project) => + val task = TestTask.findTestNamesWithFramework(project, state) + val item = task.map { classes => + classes.groupBy(_.framework).map { case (framework, classes) => + ScalaTestClassesItem(id, classes.flatMap(_.classes), Some(framework)) + }.toList + } + item + } + + Task.sequence(subTasks).map { items => + val result = ScalaTestClassesResult(items.flatten) + (state, Right(result)) + } } } diff --git a/frontend/src/main/scala/bloop/bsp/ScalaTestSuitesResult.scala b/frontend/src/main/scala/bloop/bsp/ScalaTestSuitesResult.scala new file mode 100644 index 0000000000..9a53150f56 --- /dev/null +++ b/frontend/src/main/scala/bloop/bsp/ScalaTestSuitesResult.scala @@ -0,0 +1,35 @@ +package bloop.bsp + +import io.circe.derivation.JsonCodec +import ch.epfl.scala.bsp.BuildTargetIdentifier +import io.circe.derivation.deriveDecoder +import io.circe.derivation.deriveEncoder +import io.circe.Decoder +import io.circe.Encoder +import scala.meta.jsonrpc.Endpoint +import ch.epfl.scala.bsp.ScalaTestClassesParams + +object ScalaTestClasses { + val endpoint = new Endpoint[ScalaTestClassesParams, ScalaTestClassesResult]("buildTarget/scalaTestClasses") +} + +final case class ScalaTestClassesResult( + items: List[ScalaTestClassesItem] +) + +object ScalaTestClassesResult { + implicit val decoder: Decoder[ScalaTestClassesResult] = deriveDecoder + implicit val encoder: Encoder[ScalaTestClassesResult] = deriveEncoder +} + +final case class ScalaTestClassesItem( + target: BuildTargetIdentifier, + // Fully qualified names of test classes + classes: List[String], + // Name of the sbt's test framework + framework: Option[String] +) +object ScalaTestClassesItem { + implicit val decoder: Decoder[ScalaTestClassesItem] = deriveDecoder + implicit val encoder: Encoder[ScalaTestClassesItem] = deriveEncoder +} diff --git a/frontend/src/main/scala/bloop/engine/Interpreter.scala b/frontend/src/main/scala/bloop/engine/Interpreter.scala index 4d38ec34d0..f4d08111d3 100644 --- a/frontend/src/main/scala/bloop/engine/Interpreter.scala +++ b/frontend/src/main/scala/bloop/engine/Interpreter.scala @@ -462,9 +462,10 @@ object Interpreter { stateWithNoopLogger = state.copy(logger = NoopLogger) project <- Tasks.pickTestProject(projectName, stateWithNoopLogger) } yield { - TestTask.findFullyQualifiedTestNames(project, stateWithNoopLogger).map { testsFqcn => + TestTask.findTestNamesWithFramework(project, stateWithNoopLogger).map { discovered => for { - testFqcn <- testsFqcn + classesWithFramework <- discovered + testFqcn <- classesWithFramework.classes completion <- cmd.format.showTestName(testFqcn) } state.logger.info(completion) state diff --git a/frontend/src/main/scala/bloop/engine/tasks/TestTask.scala b/frontend/src/main/scala/bloop/engine/tasks/TestTask.scala index 6023ced852..2bf4be8a70 100644 --- a/frontend/src/main/scala/bloop/engine/tasks/TestTask.scala +++ b/frontend/src/main/scala/bloop/engine/tasks/TestTask.scala @@ -8,7 +8,7 @@ import bloop.engine.tasks.toolchains.ScalaJsToolchain import bloop.exec.{Forker, JvmProcessForker} import bloop.io.AbsolutePath import bloop.logging.{DebugFilter, Logger} -import bloop.testing.{DiscoveredTestFrameworks, LoggingEventHandler, TestInternals} +import bloop.testing.{DiscoveredTestFrameworks, LoggingEventHandler, TestInternals, FingerprintInfo} import bloop.util.JavaCompat.EnrichOptional import monix.eval.Task import monix.execution.atomic.AtomicBoolean @@ -21,6 +21,11 @@ import scala.util.control.NonFatal import scala.util.{Failure, Success} import bloop.bsp.ScalaTestSuites +final case class TestFrameworkWithClasses( + framework: String, + classes: List[String] +) + object TestTask { implicit private val logContext: DebugFilter = DebugFilter.Test @@ -316,14 +321,14 @@ object TestTask { val (subclassPrints, annotatedPrints) = TestInternals.getFingerprints(frameworks) val definitions = TestInternals.potentialTests(analysis) val discovered = - Discovery(subclassPrints.map(_._1).toSet, annotatedPrints.map(_._1).toSet)(definitions) + Discovery(subclassPrints.map(_.name).toSet, annotatedPrints.map(_.name).toSet)(definitions) val tasks = mutable.Map.empty[Framework, mutable.Buffer[TaskDef]] val seen = mutable.Set.empty[String] frameworks.foreach(tasks(_) = mutable.Buffer.empty) discovered.foreach { case (defn, discovered) => TestInternals.matchingFingerprints(subclassPrints, annotatedPrints, discovered).foreach { - case (_, _, framework, fingerprint) => + case FingerprintInfo(_, _, framework, fingerprint) => if (seen.add(defn.name)) { tasks(framework) += new TaskDef( defn.name, @@ -344,28 +349,26 @@ object TestTask { * @param project The project for which to find tests. * @return An array containing all the testsFQCN that were detected. */ - def findFullyQualifiedTestNames( + def findTestNamesWithFramework( project: Project, state: State - ): Task[List[String]] = { - import state.logger + ): Task[List[TestFrameworkWithClasses]] = TestTask.discoverTestFrameworks(project, state).map { - case None => List.empty[String] + case None => List.empty case Some(found) => val frameworks = found.frameworks val lastCompileResult = state.results.lastSuccessfulResultOrEmpty(project) val analysis = lastCompileResult.previous.analysis().toOption.getOrElse { - logger.debug(s"TestsFQCN was triggered, but no compilation detected for ${project.name}")( - DebugFilter.All - ) + state.logger + .debug(s"TestsFQCN was triggered, but no compilation detected for ${project.name}")( + DebugFilter.All + ) Analysis.empty } val tests = discoverTests(analysis, frameworks) - tests.toList - .flatMap { - case (framework, tasks) => tasks.map(t => (framework, t)) - } - .map(_._2.fullyQualifiedName) + tests.map { + case (framework, tasks) => + TestFrameworkWithClasses(framework.name, tasks.map(_.fullyQualifiedName)) + }.toList } - } } diff --git a/frontend/src/main/scala/bloop/testing/TestInternals.scala b/frontend/src/main/scala/bloop/testing/TestInternals.scala index a83b560c3c..e20f26fd18 100644 --- a/frontend/src/main/scala/bloop/testing/TestInternals.scala +++ b/frontend/src/main/scala/bloop/testing/TestInternals.scala @@ -32,6 +32,13 @@ import xsbti.compile.CompileAnalysis import scala.collection.mutable import scala.util.control.NonFatal +final case class FingerprintInfo[+Print <: Fingerprint]( + name: String, + isModule: Boolean, + framework: Framework, + fingerprint: Print +) + object TestInternals { private final val sbtOrg = "org.scala-sbt" private final val testAgentId = "test-agent" @@ -42,8 +49,6 @@ object TestInternals { // Cache the resolution of test agent files since it's static (cannot be lazy because depends on logger) @volatile private var testAgentFiles: Option[Array[AbsolutePath]] = None - private type PrintInfo[F <: Fingerprint] = (String, Boolean, Framework, F) - lazy val filteredLoader = { val filter = new IncludePackagesFilter( Set( @@ -248,29 +253,29 @@ object TestInternals { def getFingerprints( frameworks: Seq[Framework] - ): (List[PrintInfo[SubclassFingerprint]], List[PrintInfo[AnnotatedFingerprint]]) = { + ): (List[FingerprintInfo[SubclassFingerprint]], List[FingerprintInfo[AnnotatedFingerprint]]) = { // The tests need to be run with the first matching framework, so we use a LinkedHashSet // to keep the ordering of `frameworks`. - val subclasses = mutable.LinkedHashSet.empty[PrintInfo[SubclassFingerprint]] - val annotated = mutable.LinkedHashSet.empty[PrintInfo[AnnotatedFingerprint]] + val subclasses = mutable.LinkedHashSet.empty[FingerprintInfo[SubclassFingerprint]] + val annotated = mutable.LinkedHashSet.empty[FingerprintInfo[AnnotatedFingerprint]] for { framework <- frameworks fingerprint <- framework.fingerprints() } fingerprint match { case sub: SubclassFingerprint => - subclasses += ((sub.superclassName, sub.isModule, framework, sub)) + subclasses += FingerprintInfo(sub.superclassName, sub.isModule, framework, sub) case ann: AnnotatedFingerprint => - annotated += ((ann.annotationName, ann.isModule, framework, ann)) + annotated += FingerprintInfo(ann.annotationName, ann.isModule, framework, ann) } (subclasses.toList, annotated.toList) } // Slightly adapted from sbt/sbt def matchingFingerprints( - subclassPrints: List[PrintInfo[SubclassFingerprint]], - annotatedPrints: List[PrintInfo[AnnotatedFingerprint]], + subclassPrints: List[FingerprintInfo[SubclassFingerprint]], + annotatedPrints: List[FingerprintInfo[AnnotatedFingerprint]], d: Discovered - ): List[PrintInfo[Fingerprint]] = { + ): List[FingerprintInfo[Fingerprint]] = { defined(subclassPrints, d.baseClasses, d.isModule) ++ defined(annotatedPrints, d.annotations, d.isModule) } @@ -316,11 +321,11 @@ object TestInternals { // Slightly adapted from sbt/sbt private def defined[T <: Fingerprint]( - in: List[PrintInfo[T]], + in: List[FingerprintInfo[T]], names: Set[String], IsModule: Boolean - ): List[PrintInfo[T]] = { - in collect { case info @ (name, IsModule, _, _) if names(name) => info } + ): List[FingerprintInfo[T]] = { + in.collect { case info @ FingerprintInfo(name, IsModule, _, _) if names(name) => info } } private def loadFramework(loader: ClassLoader, fqn: String, logger: Logger): Option[Framework] = { diff --git a/frontend/src/test/scala/bloop/bsp/BspBaseSuite.scala b/frontend/src/test/scala/bloop/bsp/BspBaseSuite.scala index 9d6833c685..3fc9483269 100644 --- a/frontend/src/test/scala/bloop/bsp/BspBaseSuite.scala +++ b/frontend/src/test/scala/bloop/bsp/BspBaseSuite.scala @@ -310,10 +310,10 @@ abstract class BspBaseSuite extends BaseSuite with BspClientTest { TestUtil.await(FiniteDuration(5, "s"))(task) } - def testClasses(project: TestProject): bsp.ScalaTestClassesResult = { + def testClasses(project: TestProject): ScalaTestClassesResult = { val task = runAfterTargets(project) { target => val params = bsp.ScalaTestClassesParams(List(target), None) - endpoints.BuildTarget.scalaTestClasses.request(params).map { + ScalaTestClasses.endpoint.request(params).map { case Left(error) => fail(s"Received error $error") case Right(result) => result } diff --git a/frontend/src/test/scala/bloop/bsp/BspProtocolSpec.scala b/frontend/src/test/scala/bloop/bsp/BspProtocolSpec.scala index 21f9141121..0f9c727b32 100644 --- a/frontend/src/test/scala/bloop/bsp/BspProtocolSpec.scala +++ b/frontend/src/test/scala/bloop/bsp/BspProtocolSpec.scala @@ -19,6 +19,7 @@ import bloop.bsp.BloopBspDefinitions.BloopExtraBuildParams import io.circe.Json import bloop.testing.DiffAssertions.TestFailedException import bloop.data.SourcesGlobs +import scala.collection.immutable object TcpBspProtocolSpec extends BspProtocolSpec(BspProtocol.Tcp) object LocalBspProtocolSpec extends BspProtocolSpec(BspProtocol.Local) @@ -288,23 +289,22 @@ class BspProtocolSpec( loadBspBuildFromResources("cross-test-build-scalajs-0.6", workspace, logger) { build => val project = build.projectFor("test-project-test") val compiledState = build.state.compile(project, timeout = 120) - val expectedClasses = Set( - "JUnitTest", - "ScalaTestTest", - "ScalaCheckTest", - "WritingTest", - "Specs2Test", - "EternalUTest", - "UTestTest", - "ResourcesTest" - ).map("hello." + _) - - val testClasses = compiledState.testClasses(project) - val items = testClasses.items - assert(items.size == 1) + val expectedSuites = Set( + ("JUnit", List("hello.JUnitTest")), + ("ScalaCheck", List("hello.ScalaCheckTest")), + ("specs2", List("hello.Specs2Test")), + ("utest", List("hello.EternalUTest", "hello.UTestTest")), + ("ScalaTest", List("hello.ScalaTestTest", "hello.WritingTest", "hello.ResourcesTest")), + ).map { case (framework, classes) => + ScalaTestClassesItem(project.bspId, classes, Some(framework)) + } + + val testSuites = compiledState.testClasses(project) + val items = testSuites.items - val classes = items.head.classes.toSet - try assertEquals(classes, expectedClasses) + assert(items.size == expectedSuites.size) + + try assertEquals(items.toSet, expectedSuites) catch { case t: TestFailedException => logger.dump(); throw t } } } @@ -539,4 +539,5 @@ class BspProtocolSpec( } } } + }