Skip to content

Commit

Permalink
Resolve Scala SDK independently for each target
Browse files Browse the repository at this point in the history
This prepares us for changes in `rules_scala` that will allow customizing the Scala version for each target.
In order to achieve that, we can no longer resolve the Scala SDK globally as the maximal version used. We need to use per-target info.
Scala SDK will be still discovered based on the compiler class path.
Now though we will look into one of a dep providers of a Scala toolchain, namely the `scala_compile_classpath` – a canonical place to put all compile-related jars.

This change is backward-compatible, as the mentioned data is already available.
It is also forward-compatible with the anticipated cross-build feature of `rules_scala`.
(see: bazelbuild/rules_scala#1290)

The aspect will produce additional data – namely few compiler classpath jars per Scala target.
  • Loading branch information
aszady committed Apr 8, 2024
1 parent 7aeaf8a commit 1c14349
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 17 deletions.
20 changes: 20 additions & 0 deletions aspects/rules/scala/scala_info.bzl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
load("@io_bazel_rules_scala//scala:providers.bzl", "DepsInfo")
load("//aspects:utils/utils.bzl", "create_proto", "file_location", "is_external", "map", "update_sync_output_groups")

def find_scalac_classpath(runfiles):
Expand Down Expand Up @@ -43,8 +44,27 @@ def extract_scala_info(target, ctx, output_groups, **kwargs):
if hasattr(ctx.rule.attr, "_scala_toolchain"):
toolchain = ctx.toolchains[SCALA_TOOLCHAIN]
common_scalac_opts = toolchain.scalacopts
dep_providers = toolchain.dep_providers
compiler_classpath = _extract_compiler_classpath(dep_providers)
if compiler_classpath:
scala_info["compiler_classpath"] = compiler_classpath
else:
common_scalac_opts = []
scala_info["scalac_opts"] = common_scalac_opts + getattr(ctx.rule.attr, "scalacopts", [])

return create_proto(target, ctx, struct(**scala_info), "scala_target_info"), None

def _extract_compiler_classpath(dep_providers):
for dep_provider in dep_providers:
if DepsInfo not in dep_provider:
continue
deps_info = dep_provider[DepsInfo]
if deps_info.deps_id != "scala_compile_classpath":
continue

compile_jars = []
for dep in deps_info.deps:
compile_jars.extend(dep[JavaInfo].compile_jars.to_list())
classpath = find_scalac_classpath(compile_jars)
return map(file_location, classpath)
return []
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,20 @@ class ScalaLanguagePlugin(

var scalaSdk: ScalaSdk? = null

override fun prepareSync(targets: Sequence<BspTargetInfo.TargetInfo>) {
scalaSdk = ScalaSdkResolver(bazelPathsResolver).resolve(targets)
}
private val scalaSdkResolver = ScalaSdkResolver(bazelPathsResolver)

override fun resolveModule(targetInfo: BspTargetInfo.TargetInfo): ScalaModule? {
if (!targetInfo.hasScalaTargetInfo()) {
return null
}
val scalaTargetInfo = targetInfo.scalaTargetInfo
val sdk = getScalaSdkOrThrow()
val sdk = getScalaSdkOrThrow(targetInfo)
val scalacOpts = scalaTargetInfo.scalacOptsList
return ScalaModule(sdk, scalacOpts, javaLanguagePlugin.resolveModule(targetInfo))
}

private fun getScalaSdkOrThrow(): ScalaSdk =
scalaSdk ?: throw RuntimeException("Failed to resolve Scala SDK for project")
private fun getScalaSdkOrThrow(targetInfo: BspTargetInfo.TargetInfo): ScalaSdk =
scalaSdkResolver.resolveSdk(targetInfo) ?: throw RuntimeException("Failed to resolve Scala SDK for target")

override fun dependencySources(
targetInfo: BspTargetInfo.TargetInfo,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,15 @@ import kotlin.math.min

class ScalaSdkResolver(private val bazelPathsResolver: BazelPathsResolver) {

fun resolve(targets: Sequence<BspTargetInfo.TargetInfo>): ScalaSdk? =
targets
.mapNotNull(::resolveSdk)
.distinct()
.sortedWith(SCALA_VERSION_COMPARATOR)
.lastOrNull()
fun resolve(targets: Sequence<BspTargetInfo.TargetInfo>): ScalaSdk? = null

private fun resolveSdk(targetInfo: BspTargetInfo.TargetInfo): ScalaSdk? {
if (!targetInfo.hasScalaToolchainInfo()) {
fun resolveSdk(targetInfo: BspTargetInfo.TargetInfo): ScalaSdk? {
if (!targetInfo.hasScalaTargetInfo()) {
return null
}
val scalaToolchain = targetInfo.scalaToolchainInfo
val scalaTarget = targetInfo.scalaTargetInfo
val compilerJars =
bazelPathsResolver.resolvePaths(scalaToolchain.compilerClasspathList).sorted()
bazelPathsResolver.resolvePaths(scalaTarget.compilerClasspathList).sorted()
val maybeVersions = compilerJars.mapNotNull(::extractVersion)
if (maybeVersions.none()) {
return null
Expand Down Expand Up @@ -58,6 +53,6 @@ class ScalaSdkResolver(private val bazelPathsResolver: BazelPathsResolver) {
0
}
private val VERSION_PATTERN =
Pattern.compile("(?:processed_)?scala3?-(?:library|compiler|reflect)(?:_3)?-([.\\d]+)\\.jar")
Pattern.compile("(?:processed_)?scala3?-(?:library|compiler|reflect)(?:_3)?-([.\\d]+)(?:-stamped)?\\.jar")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ message JavaRuntimeInfo {

message ScalaTargetInfo {
repeated string scalac_opts = 1;
repeated FileLocation compiler_classpath = 2;
}

message ScalaToolchainInfo {
Expand Down

0 comments on commit 1c14349

Please sign in to comment.