Skip to content

Commit

Permalink
Allow for codegen targets to be used directly by JVM compiler requests (
Browse files Browse the repository at this point in the history
#14751)

This amends `ClasspathEntryRequest.for_targets` to accept a codegen source request target and output a `ClasspathEntryRequest` that corresponds to a relevant JVM compilation request.

With this approach, it seems that any `GenerateSourcesRequest` with `output` that is compatible with a `ClasspathEntryRequest` will Just Work™. To demonstrate this in action, there's backend registration support for the Java and Scala protobuf generators included in this PR too. 

The main current limitation is when multiple language backends are enabled for a given codegen source type. A solution for this would address #14041, and realistically needs to take multiple JVM languages into account.
  • Loading branch information
Christopher Neugebauer authored Mar 11, 2022
1 parent 86d9183 commit 60c19d1
Show file tree
Hide file tree
Showing 10 changed files with 232 additions and 57 deletions.
33 changes: 33 additions & 0 deletions src/python/pants/backend/codegen/protobuf/java/register.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2022 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).

"""Generate Java sources from Protocol Buffers (Protobufs).
See https://www.pantsbuild.org/docs/protobuf.
"""

from pants.backend.codegen import export_codegen_goal
from pants.backend.codegen.protobuf import protobuf_dependency_inference
from pants.backend.codegen.protobuf import tailor as protobuf_tailor
from pants.backend.codegen.protobuf.java.rules import rules as java_rules
from pants.backend.codegen.protobuf.target_types import (
ProtobufSourcesGeneratorTarget,
ProtobufSourceTarget,
)
from pants.backend.codegen.protobuf.target_types import rules as protobuf_target_rules
from pants.core.util_rules import stripped_source_files


def rules():
return [
*java_rules(),
*protobuf_dependency_inference.rules(),
*protobuf_tailor.rules(),
*export_codegen_goal.rules(),
*protobuf_target_rules(),
*stripped_source_files.rules(),
]


def target_types():
return [ProtobufSourcesGeneratorTarget, ProtobufSourceTarget]
13 changes: 12 additions & 1 deletion src/python/pants/backend/codegen/protobuf/java/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@


from pants.backend.codegen.protobuf.protoc import Protoc
from pants.backend.codegen.protobuf.target_types import ProtobufSourceField
from pants.backend.codegen.protobuf.target_types import (
ProtobufSourceField,
ProtobufSourcesGeneratorTarget,
ProtobufSourceTarget,
)
from pants.backend.java.target_types import JavaSourceField
from pants.backend.python.util_rules import pex
from pants.core.util_rules.external_tool import DownloadedExternalTool, ExternalToolRequest
Expand All @@ -28,6 +32,7 @@
TransitiveTargetsRequest,
)
from pants.engine.unions import UnionRule
from pants.jvm.target_types import JvmJdkField
from pants.source.source_root import SourceRoot, SourceRootRequest
from pants.util.logging import LogLevel

Expand Down Expand Up @@ -116,9 +121,15 @@ async def generate_java_from_protobuf(
return GeneratedSources(source_root_restored)


class PrefixedJvmJdkField(JvmJdkField):
alias = "jvm_jdk"


def rules():
return [
*collect_rules(),
*pex.rules(),
UnionRule(GenerateSourcesRequest, GenerateJavaFromProtobufRequest),
ProtobufSourceTarget.register_plugin_field(PrefixedJvmJdkField),
ProtobufSourcesGeneratorTarget.register_plugin_field(PrefixedJvmJdkField),
]
33 changes: 33 additions & 0 deletions src/python/pants/backend/codegen/protobuf/scala/register.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2022 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).

"""Generate Scala sources from Protocol Buffers (Protobufs).
See https://www.pantsbuild.org/docs/protobuf.
"""

from pants.backend.codegen import export_codegen_goal
from pants.backend.codegen.protobuf import protobuf_dependency_inference
from pants.backend.codegen.protobuf import tailor as protobuf_tailor
from pants.backend.codegen.protobuf.scala.rules import rules as scala_rules
from pants.backend.codegen.protobuf.target_types import (
ProtobufSourcesGeneratorTarget,
ProtobufSourceTarget,
)
from pants.backend.codegen.protobuf.target_types import rules as protobuf_target_rules
from pants.core.util_rules import stripped_source_files


def rules():
return [
*scala_rules(),
*protobuf_dependency_inference.rules(),
*protobuf_tailor.rules(),
*export_codegen_goal.rules(),
*protobuf_target_rules(),
*stripped_source_files.rules(),
]


def target_types():
return [ProtobufSourcesGeneratorTarget, ProtobufSourceTarget]
13 changes: 7 additions & 6 deletions src/python/pants/backend/java/bsp/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@
)
from pants.engine.unions import UnionMembership, UnionRule
from pants.jvm.bsp.spec import JvmBuildTarget
from pants.jvm.compile import ClasspathEntryRequest, FallibleClasspathEntry
from pants.jvm.compile import (
ClasspathEntryRequest,
ClasspathEntryRequestFactory,
FallibleClasspathEntry,
)
from pants.jvm.resolve.key import CoursierResolveKey

LANGUAGE_ID = "java"
Expand Down Expand Up @@ -178,8 +182,7 @@ class JavaBSPCompileFieldSet(BSPCompileFieldSet):

@rule
async def bsp_java_compile_request(
request: JavaBSPCompileFieldSet,
union_membership: UnionMembership,
request: JavaBSPCompileFieldSet, classpath_entry_request: ClasspathEntryRequestFactory
) -> BSPCompileResult:
coarsened_targets = await Get(CoarsenedTargets, Addresses([request.source.address]))
assert len(coarsened_targets) == 1
Expand All @@ -189,9 +192,7 @@ async def bsp_java_compile_request(
result = await Get(
FallibleClasspathEntry,
ClasspathEntryRequest,
ClasspathEntryRequest.for_targets(
union_membership, component=coarsened_target, resolve=resolve
),
classpath_entry_request.for_targets(component=coarsened_target, resolve=resolve),
)
_logger.info(f"java compile result = {result}")
output_digest = EMPTY_DIGEST
Expand Down
12 changes: 8 additions & 4 deletions src/python/pants/backend/java/goals/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@
from pants.engine.addresses import Addresses
from pants.engine.rules import Get, MultiGet, collect_rules, rule
from pants.engine.target import CoarsenedTargets
from pants.engine.unions import UnionMembership, UnionRule
from pants.jvm.compile import ClasspathEntryRequest, FallibleClasspathEntry
from pants.engine.unions import UnionRule
from pants.jvm.compile import (
ClasspathEntryRequest,
ClasspathEntryRequestFactory,
FallibleClasspathEntry,
)
from pants.jvm.resolve.key import CoursierResolveKey
from pants.util.logging import LogLevel

Expand All @@ -27,7 +31,7 @@ class JavacCheckRequest(CheckRequest):
@rule(desc="Check javac compilation", level=LogLevel.DEBUG)
async def javac_check(
request: JavacCheckRequest,
union_membership: UnionMembership,
classpath_entry_request: ClasspathEntryRequestFactory,
) -> CheckResults:
coarsened_targets = await Get(
CoarsenedTargets, Addresses(field_set.address for field_set in request.field_sets)
Expand All @@ -43,7 +47,7 @@ async def javac_check(
Get(
FallibleClasspathEntry,
ClasspathEntryRequest,
ClasspathEntryRequest.for_targets(union_membership, component=target, resolve=resolve),
classpath_entry_request.for_targets(component=target, resolve=resolve),
)
for target, resolve in zip(coarsened_targets, resolves)
)
Expand Down
15 changes: 7 additions & 8 deletions src/python/pants/backend/scala/bsp/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from pants.bsp.util_rules.lifecycle import BSPLanguageSupport
from pants.bsp.util_rules.targets import BSPBuildTargets, BSPBuildTargetsRequest
from pants.build_graph.address import Address, AddressInput
from pants.core.util_rules.system_binaries import BashBinary, UnzipBinary
from pants.engine.addresses import Addresses
from pants.engine.fs import EMPTY_DIGEST, AddPrefix, CreateDigest, Digest, DigestEntries
from pants.engine.internals.selectors import Get, MultiGet
Expand All @@ -42,7 +41,11 @@
WrappedTarget,
)
from pants.engine.unions import UnionMembership, UnionRule
from pants.jvm.compile import ClasspathEntryRequest, FallibleClasspathEntry
from pants.jvm.compile import (
ClasspathEntryRequest,
ClasspathEntryRequestFactory,
FallibleClasspathEntry,
)
from pants.jvm.resolve.key import CoursierResolveKey
from pants.jvm.subsystems import JvmSubsystem
from pants.jvm.target_types import JvmResolveField
Expand Down Expand Up @@ -199,9 +202,7 @@ class ScalaBSPCompileFieldSet(BSPCompileFieldSet):
@rule
async def bsp_scala_compile_request(
request: ScalaBSPCompileFieldSet,
union_membership: UnionMembership,
unzip: UnzipBinary,
bash: BashBinary,
classpath_entry_request: ClasspathEntryRequestFactory,
) -> BSPCompileResult:
coarsened_targets = await Get(CoarsenedTargets, Addresses([request.source.address]))
assert len(coarsened_targets) == 1
Expand All @@ -211,9 +212,7 @@ async def bsp_scala_compile_request(
result = await Get(
FallibleClasspathEntry,
ClasspathEntryRequest,
ClasspathEntryRequest.for_targets(
union_membership, component=coarsened_target, resolve=resolve
),
classpath_entry_request.for_targets(component=coarsened_target, resolve=resolve),
)
_logger.info(f"scala compile result = {result}")
output_digest = EMPTY_DIGEST
Expand Down
12 changes: 8 additions & 4 deletions src/python/pants/backend/scala/goals/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@
from pants.engine.addresses import Addresses
from pants.engine.rules import Get, MultiGet, collect_rules, rule
from pants.engine.target import CoarsenedTargets
from pants.engine.unions import UnionMembership, UnionRule
from pants.jvm.compile import ClasspathEntryRequest, FallibleClasspathEntry
from pants.engine.unions import UnionRule
from pants.jvm.compile import (
ClasspathEntryRequest,
ClasspathEntryRequestFactory,
FallibleClasspathEntry,
)
from pants.jvm.resolve.key import CoursierResolveKey
from pants.util.logging import LogLevel

Expand All @@ -27,7 +31,7 @@ class ScalacCheckRequest(CheckRequest):
@rule(desc="Check compilation for Scala", level=LogLevel.DEBUG)
async def scalac_check(
request: ScalacCheckRequest,
union_membership: UnionMembership,
classpath_entry_request: ClasspathEntryRequestFactory,
) -> CheckResults:
coarsened_targets = await Get(
CoarsenedTargets, Addresses(field_set.address for field_set in request.field_sets)
Expand All @@ -43,7 +47,7 @@ async def scalac_check(
Get(
FallibleClasspathEntry,
ClasspathEntryRequest,
ClasspathEntryRequest.for_targets(union_membership, component=target, resolve=resolve),
classpath_entry_request.for_targets(component=target, resolve=resolve),
)
for target, resolve in zip(coarsened_targets, resolves)
)
Expand Down
9 changes: 3 additions & 6 deletions src/python/pants/jvm/classpath.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from pants.engine.fs import Digest
from pants.engine.rules import Get, MultiGet, collect_rules, rule
from pants.engine.target import CoarsenedTargets
from pants.engine.unions import UnionMembership
from pants.jvm.compile import ClasspathEntry, ClasspathEntryRequest
from pants.jvm.compile import ClasspathEntry, ClasspathEntryRequest, ClasspathEntryRequestFactory
from pants.jvm.resolve.key import CoursierResolveKey

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -69,7 +68,7 @@ def root_immutable_inputs_args(self, *, prefix: str = "") -> Iterator[str]:
@rule
async def classpath(
coarsened_targets: CoarsenedTargets,
union_membership: UnionMembership,
classpath_entry_request: ClasspathEntryRequestFactory,
) -> Classpath:
# Compute a single shared resolve for all of the roots, which will validate that they
# are compatible with one another.
Expand All @@ -80,9 +79,7 @@ async def classpath(
Get(
ClasspathEntry,
ClasspathEntryRequest,
ClasspathEntryRequest.for_targets(
union_membership, component=t, resolve=resolve, root=True
),
classpath_entry_request.for_targets(component=t, resolve=resolve, root=True),
)
for t in coarsened_targets
)
Expand Down
Loading

0 comments on commit 60c19d1

Please sign in to comment.