Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

scala: support multiple scala versions #14425

Merged
merged 8 commits into from
Feb 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/python/pants/backend/java/compile/javac.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ async def compile_java_source(
)

usercp = "__cp"
user_classpath = Classpath(direct_dependency_classpath_entries)
user_classpath = Classpath(direct_dependency_classpath_entries, request.resolve)
classpath_arg = ":".join(user_classpath.root_immutable_inputs_args(prefix=usercp))
immutable_input_digests = dict(user_classpath.root_immutable_inputs(prefix=usercp))

Expand Down
11 changes: 7 additions & 4 deletions src/python/pants/backend/scala/compile/scalac.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,14 @@ async def compile_scala_source(
for filename in dependency.filenames
]

scala_version = scala.version_for_resolve(request.resolve.name)

# TODO(14171): Stop-gap for making sure that scala supplies all of its dependencies to
# deploy targets.
if not any(
filename.startswith("org.scala-lang_scala-library_") for filename in all_dependency_jars
):
scala_library = await Get(ClasspathEntry, ScalaLibraryRequest(scala.version))
scala_library = await Get(ClasspathEntry, ScalaLibraryRequest(scala_version))
direct_dependency_classpath_entries += (scala_library,)

component_members_with_sources = tuple(
Expand Down Expand Up @@ -130,7 +132,8 @@ async def compile_scala_source(
scalac_plugins_relpath = "__plugincp"
usercp = "__cp"

user_classpath = Classpath(direct_dependency_classpath_entries)
user_classpath = Classpath(direct_dependency_classpath_entries, request.resolve)

tool_classpath, sources_digest = await MultiGet(
Get(
ToolClasspath,
Expand All @@ -140,12 +143,12 @@ async def compile_scala_source(
Coordinate(
group="org.scala-lang",
artifact="scala-compiler",
version=scala.version,
version=scala_version,
),
Coordinate(
group="org.scala-lang",
artifact="scala-library",
version=scala.version,
version=scala_version,
),
]
),
Expand Down
65 changes: 64 additions & 1 deletion src/python/pants/backend/scala/compile/scalac_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from __future__ import annotations

import itertools
from textwrap import dedent

import pytest
Expand All @@ -20,7 +21,7 @@
from pants.engine.fs import FileDigest
from pants.engine.target import CoarsenedTargets
from pants.jvm import jdk_rules, testutil
from pants.jvm.compile import CompileResult, FallibleClasspathEntry
from pants.jvm.compile import ClasspathEntry, CompileResult, FallibleClasspathEntry
from pants.jvm.resolve.common import ArtifactRequirement, Coordinate, Coordinates
from pants.jvm.resolve.coursier_fetch import CoursierLockfileEntry, CoursierResolvedLockfile
from pants.jvm.resolve.coursier_fetch import rules as coursier_fetch_rules
Expand Down Expand Up @@ -52,6 +53,7 @@ def rule_runner() -> RuleRunner:
QueryRule(CoarsenedTargets, (Addresses,)),
QueryRule(FallibleClasspathEntry, (CompileScalaSourceRequest,)),
QueryRule(RenderedClasspath, (CompileScalaSourceRequest,)),
QueryRule(ClasspathEntry, (CompileScalaSourceRequest,)),
],
target_types=[JvmArtifactTarget, ScalaSourcesGeneratorTarget, ScalacPluginTarget],
)
Expand Down Expand Up @@ -649,3 +651,64 @@ def example: Option[String] = {
resolve=make_resolve(rule_runner),
)
rule_runner.request(RenderedClasspath, [request])


@maybe_skip_jdk_test
def test_compile_with_multiple_scala_versions(rule_runner: RuleRunner) -> None:
rule_runner.write_files(
{
"BUILD": dedent(
"""\
scala_sources(
name = 'main',
)
"""
),
"Example.scala": SCALA_LIB_SOURCE,
"3rdparty/jvm/scala2.12.lock": EMPTY_LOCKFILE,
"3rdparty/jvm/scala2.13.lock": EMPTY_LOCKFILE,
}
)
rule_runner.set_options(
['--scala-version-for-resolve={"scala2.12":"2.12.15","scala2.13":"2.13.8"}'],
env_inherit=PYTHON_BOOTSTRAP_ENV,
)
classpath_2_12 = rule_runner.request(
ClasspathEntry,
[
CompileScalaSourceRequest(
component=expect_single_expanded_coarsened_target(
rule_runner, Address(spec_path="", target_name="main")
),
resolve=make_resolve(rule_runner, "scala2.12", "3rdparty/jvm/scala2.12.lock"),
)
],
)
entries_2_12 = list(ClasspathEntry.closure([classpath_2_12]))
filenames_2_12 = sorted(
itertools.chain.from_iterable(entry.filenames for entry in entries_2_12)
)
assert filenames_2_12 == [
".Example.scala.main.scalac.jar",
"org.scala-lang_scala-library_2.12.15.jar",
]

classpath_2_13 = rule_runner.request(
ClasspathEntry,
[
CompileScalaSourceRequest(
component=expect_single_expanded_coarsened_target(
rule_runner, Address(spec_path="", target_name="main")
),
resolve=make_resolve(rule_runner, "scala2.13", "3rdparty/jvm/scala2.13.lock"),
)
],
)
entries_2_13 = list(ClasspathEntry.closure([classpath_2_13]))
filenames_2_13 = sorted(
itertools.chain.from_iterable(entry.filenames for entry in entries_2_13)
)
assert filenames_2_13 == [
".Example.scala.main.scalac.jar",
"org.scala-lang_scala-library_2.13.8.jar",
]
50 changes: 25 additions & 25 deletions src/python/pants/backend/scala/goals/repl.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,31 +26,31 @@ async def create_scala_repl_request(
request: ScalaRepl, bash: BashBinary, jdk_setup: JdkSetup, scala_subsystem: ScalaSubsystem
) -> ReplRequest:
jdk = jdk_setup.jdk
user_classpath, tool_classpath = await MultiGet(
Get(Classpath, Addresses, request.addresses),
Get(
ToolClasspath,
ToolClasspathRequest(
prefix="__toolcp",
artifact_requirements=ArtifactRequirements.from_coordinates(
[
Coordinate(
group="org.scala-lang",
artifact="scala-compiler",
version=scala_subsystem.version,
),
Coordinate(
group="org.scala-lang",
artifact="scala-library",
version=scala_subsystem.version,
),
Coordinate(
group="org.scala-lang",
artifact="scala-reflect",
version=scala_subsystem.version,
),
]
),

user_classpath = await Get(Classpath, Addresses, request.addresses)
scala_version = scala_subsystem.version_for_resolve(user_classpath.resolve.name)
tool_classpath = await Get(
ToolClasspath,
ToolClasspathRequest(
prefix="__toolcp",
artifact_requirements=ArtifactRequirements.from_coordinates(
[
Coordinate(
group="org.scala-lang",
artifact="scala-compiler",
version=scala_version,
),
Coordinate(
group="org.scala-lang",
artifact="scala-library",
version=scala_version,
),
Coordinate(
group="org.scala-lang",
artifact="scala-reflect",
version=scala_version,
),
]
),
),
)
Expand Down
44 changes: 42 additions & 2 deletions src/python/pants/backend/scala/subsystems/scala.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,56 @@

from __future__ import annotations

import logging
from typing import cast

from pants.option.option_types import StrOption
from pants.option.subsystem import Subsystem

DEFAULT_SCALA_VERSION = "2.13.6"

_logger = logging.getLogger(__name__)


class ScalaSubsystem(Subsystem):
options_scope = "scala"
help = "Scala programming language"

version = StrOption(
"--version", default=DEFAULT_SCALA_VERSION, help="The version of Scala to use"
)
"--version",
default=DEFAULT_SCALA_VERSION,
help=(
"The version of Scala to use.\n\n"
"This option is deprecated in favor of the `[scala].version_for_resolve` option. If "
"`[scala].version_for_resolve` does not have an entry for a resolve, then the value of "
"this option will be used as the Scala version for that resolve."
),
).deprecated(removal_version="2.11.0.dev0", hint="Use `[scala].version_for_resolve` instead.")

@classmethod
def register_options(cls, register):
super().register_options(register)

register(
"--version-for-resolve",
type=dict,
help=(
"A dictionary mapping the name of a resolve to the Scala version to use for all Scala "
"targets consuming that resolve.\n\n"
'All Scala-compiled jars on a resolve\'s classpath must be "compatible" with one another and '
"with all Scala-compiled first-party sources from `scala_sources` (and other Scala target types) "
"using that resolve. The option sets the Scala version that will be used to compile all "
"first-party sources using the resolve. This ensures that the compatibility property is "
"maintained for a resolve. To support multiple Scala versions, use multiple resolves."
),
)

@property
def _version_for_resolve(self) -> dict[str, str]:
return cast("dict[str, str]", self.options.version_for_resolve)

def version_for_resolve(self, resolve: str) -> str:
version = self._version_for_resolve.get(resolve)
if version:
return version
return self.version
26 changes: 16 additions & 10 deletions src/python/pants/jvm/classpath.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import Iterator

from pants.engine.collection import Collection
from pants.engine.fs import Digest
from pants.engine.rules import Get, MultiGet, collect_rules, rule
from pants.engine.target import CoarsenedTargets
Expand All @@ -17,7 +17,8 @@
logger = logging.getLogger(__name__)


class Classpath(Collection[ClasspathEntry]):
@dataclass(frozen=True)
class Classpath:
"""A transitive classpath which is sufficient to launch the target(s) it was generated for.

There are two primary ways to consume a Classpath:
Expand All @@ -31,33 +32,38 @@ class Classpath(Collection[ClasspathEntry]):
This classpath is guaranteed to contain only JAR files.
"""

entries: tuple[ClasspathEntry, ...]
resolve: CoursierResolveKey

def args(self, *, prefix: str = "") -> Iterator[str]:
"""All transitive filenames for this Classpath."""
return ClasspathEntry.args(ClasspathEntry.closure(self), prefix=prefix)
return ClasspathEntry.args(ClasspathEntry.closure(self.entries), prefix=prefix)

def root_args(self, *, prefix: str = "") -> Iterator[str]:
"""The root filenames for this Classpath."""
return ClasspathEntry.args(self, prefix=prefix)
return ClasspathEntry.args(self.entries, prefix=prefix)

def digests(self) -> Iterator[Digest]:
"""All transitive Digests for this Classpath."""
return (entry.digest for entry in ClasspathEntry.closure(self))
return (entry.digest for entry in ClasspathEntry.closure(self.entries))

def immutable_inputs(self, *, prefix: str = "") -> Iterator[tuple[str, Digest]]:
"""Returns (relpath, Digest) tuples for use with `Process.immutable_input_digests`."""
return ClasspathEntry.immutable_inputs(ClasspathEntry.closure(self), prefix=prefix)
return ClasspathEntry.immutable_inputs(ClasspathEntry.closure(self.entries), prefix=prefix)

def immutable_inputs_args(self, *, prefix: str = "") -> Iterator[str]:
"""Returns relative filenames for the given entries to be used as immutable_inputs."""
return ClasspathEntry.immutable_inputs_args(ClasspathEntry.closure(self), prefix=prefix)
return ClasspathEntry.immutable_inputs_args(
ClasspathEntry.closure(self.entries), prefix=prefix
)

def root_immutable_inputs(self, *, prefix: str = "") -> Iterator[tuple[str, Digest]]:
"""Returns root (relpath, Digest) tuples for use with `Process.immutable_input_digests`."""
return ClasspathEntry.immutable_inputs(self, prefix=prefix)
return ClasspathEntry.immutable_inputs(self.entries, prefix=prefix)

def root_immutable_inputs_args(self, *, prefix: str = "") -> Iterator[str]:
"""Returns root relative filenames for the given entries to be used as immutable_inputs."""
return ClasspathEntry.immutable_inputs_args(self, prefix=prefix)
return ClasspathEntry.immutable_inputs_args(self.entries, prefix=prefix)


@rule
Expand All @@ -81,7 +87,7 @@ async def classpath(
for t in coarsened_targets
)

return Classpath(classpath_entries)
return Classpath(classpath_entries, resolve)


def rules():
Expand Down
3 changes: 2 additions & 1 deletion src/python/pants/jvm/testutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ async def render_classpath_entry(
@rule
async def render_classpath(classpath: Classpath) -> RenderedClasspath:
rendered_classpaths = await MultiGet(
Get(RenderedClasspath, ClasspathEntry, cpe) for cpe in ClasspathEntry.closure(classpath)
Get(RenderedClasspath, ClasspathEntry, cpe)
for cpe in ClasspathEntry.closure(classpath.entries)
)
return RenderedClasspath({k: v for rc in rendered_classpaths for k, v in rc.content.items()})

Expand Down