From 60c19d1a24ab460c0f99a0d2a08411951719e062 Mon Sep 17 00:00:00 2001 From: Christopher Neugebauer Date: Fri, 11 Mar 2022 14:00:33 -0800 Subject: [PATCH] Allow for codegen targets to be used directly by JVM compiler requests (#14751) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This amends `ClasspathEntryRequest.for_targets` to accept a codegen source request target and output a `ClasspathEntryRequest` that corresponds to a relevant JVM compilation request. With this approach, it seems that any `GenerateSourcesRequest` with `output` that is compatible with a `ClasspathEntryRequest` will Just Work™. To demonstrate this in action, there's backend registration support for the Java and Scala protobuf generators included in this PR too. The main current limitation is when multiple language backends are enabled for a given codegen source type. A solution for this would address #14041, and realistically needs to take multiple JVM languages into account. --- .../backend/codegen/protobuf/java/register.py | 33 ++++++++ .../backend/codegen/protobuf/java/rules.py | 13 ++- .../codegen/protobuf/scala/register.py | 33 ++++++++ src/python/pants/backend/java/bsp/rules.py | 13 +-- src/python/pants/backend/java/goals/check.py | 12 ++- src/python/pants/backend/scala/bsp/rules.py | 15 ++-- src/python/pants/backend/scala/goals/check.py | 12 ++- src/python/pants/jvm/classpath.py | 9 +- src/python/pants/jvm/compile.py | 65 +++++++++++--- src/python/pants/jvm/compile_test.py | 84 +++++++++++++++---- 10 files changed, 232 insertions(+), 57 deletions(-) create mode 100644 src/python/pants/backend/codegen/protobuf/java/register.py create mode 100644 src/python/pants/backend/codegen/protobuf/scala/register.py diff --git a/src/python/pants/backend/codegen/protobuf/java/register.py b/src/python/pants/backend/codegen/protobuf/java/register.py new file mode 100644 index 00000000000..ef510a37ec1 --- /dev/null +++ b/src/python/pants/backend/codegen/protobuf/java/register.py @@ -0,0 +1,33 @@ +# Copyright 2022 Pants project contributors (see CONTRIBUTORS.md). +# Licensed under the Apache License, Version 2.0 (see LICENSE). + +"""Generate Java sources from Protocol Buffers (Protobufs). + +See https://www.pantsbuild.org/docs/protobuf. +""" + +from pants.backend.codegen import export_codegen_goal +from pants.backend.codegen.protobuf import protobuf_dependency_inference +from pants.backend.codegen.protobuf import tailor as protobuf_tailor +from pants.backend.codegen.protobuf.java.rules import rules as java_rules +from pants.backend.codegen.protobuf.target_types import ( + ProtobufSourcesGeneratorTarget, + ProtobufSourceTarget, +) +from pants.backend.codegen.protobuf.target_types import rules as protobuf_target_rules +from pants.core.util_rules import stripped_source_files + + +def rules(): + return [ + *java_rules(), + *protobuf_dependency_inference.rules(), + *protobuf_tailor.rules(), + *export_codegen_goal.rules(), + *protobuf_target_rules(), + *stripped_source_files.rules(), + ] + + +def target_types(): + return [ProtobufSourcesGeneratorTarget, ProtobufSourceTarget] diff --git a/src/python/pants/backend/codegen/protobuf/java/rules.py b/src/python/pants/backend/codegen/protobuf/java/rules.py index f323a412a51..0c4a4e106e0 100644 --- a/src/python/pants/backend/codegen/protobuf/java/rules.py +++ b/src/python/pants/backend/codegen/protobuf/java/rules.py @@ -3,7 +3,11 @@ from pants.backend.codegen.protobuf.protoc import Protoc -from pants.backend.codegen.protobuf.target_types import ProtobufSourceField +from pants.backend.codegen.protobuf.target_types import ( + ProtobufSourceField, + ProtobufSourcesGeneratorTarget, + ProtobufSourceTarget, +) from pants.backend.java.target_types import JavaSourceField from pants.backend.python.util_rules import pex from pants.core.util_rules.external_tool import DownloadedExternalTool, ExternalToolRequest @@ -28,6 +32,7 @@ TransitiveTargetsRequest, ) from pants.engine.unions import UnionRule +from pants.jvm.target_types import JvmJdkField from pants.source.source_root import SourceRoot, SourceRootRequest from pants.util.logging import LogLevel @@ -116,9 +121,15 @@ async def generate_java_from_protobuf( return GeneratedSources(source_root_restored) +class PrefixedJvmJdkField(JvmJdkField): + alias = "jvm_jdk" + + def rules(): return [ *collect_rules(), *pex.rules(), UnionRule(GenerateSourcesRequest, GenerateJavaFromProtobufRequest), + ProtobufSourceTarget.register_plugin_field(PrefixedJvmJdkField), + ProtobufSourcesGeneratorTarget.register_plugin_field(PrefixedJvmJdkField), ] diff --git a/src/python/pants/backend/codegen/protobuf/scala/register.py b/src/python/pants/backend/codegen/protobuf/scala/register.py new file mode 100644 index 00000000000..dc167171f3d --- /dev/null +++ b/src/python/pants/backend/codegen/protobuf/scala/register.py @@ -0,0 +1,33 @@ +# Copyright 2022 Pants project contributors (see CONTRIBUTORS.md). +# Licensed under the Apache License, Version 2.0 (see LICENSE). + +"""Generate Scala sources from Protocol Buffers (Protobufs). + +See https://www.pantsbuild.org/docs/protobuf. +""" + +from pants.backend.codegen import export_codegen_goal +from pants.backend.codegen.protobuf import protobuf_dependency_inference +from pants.backend.codegen.protobuf import tailor as protobuf_tailor +from pants.backend.codegen.protobuf.scala.rules import rules as scala_rules +from pants.backend.codegen.protobuf.target_types import ( + ProtobufSourcesGeneratorTarget, + ProtobufSourceTarget, +) +from pants.backend.codegen.protobuf.target_types import rules as protobuf_target_rules +from pants.core.util_rules import stripped_source_files + + +def rules(): + return [ + *scala_rules(), + *protobuf_dependency_inference.rules(), + *protobuf_tailor.rules(), + *export_codegen_goal.rules(), + *protobuf_target_rules(), + *stripped_source_files.rules(), + ] + + +def target_types(): + return [ProtobufSourcesGeneratorTarget, ProtobufSourceTarget] diff --git a/src/python/pants/backend/java/bsp/rules.py b/src/python/pants/backend/java/bsp/rules.py index c5ff1a837ee..58487d88ca5 100644 --- a/src/python/pants/backend/java/bsp/rules.py +++ b/src/python/pants/backend/java/bsp/rules.py @@ -36,7 +36,11 @@ ) from pants.engine.unions import UnionMembership, UnionRule from pants.jvm.bsp.spec import JvmBuildTarget -from pants.jvm.compile import ClasspathEntryRequest, FallibleClasspathEntry +from pants.jvm.compile import ( + ClasspathEntryRequest, + ClasspathEntryRequestFactory, + FallibleClasspathEntry, +) from pants.jvm.resolve.key import CoursierResolveKey LANGUAGE_ID = "java" @@ -178,8 +182,7 @@ class JavaBSPCompileFieldSet(BSPCompileFieldSet): @rule async def bsp_java_compile_request( - request: JavaBSPCompileFieldSet, - union_membership: UnionMembership, + request: JavaBSPCompileFieldSet, classpath_entry_request: ClasspathEntryRequestFactory ) -> BSPCompileResult: coarsened_targets = await Get(CoarsenedTargets, Addresses([request.source.address])) assert len(coarsened_targets) == 1 @@ -189,9 +192,7 @@ async def bsp_java_compile_request( result = await Get( FallibleClasspathEntry, ClasspathEntryRequest, - ClasspathEntryRequest.for_targets( - union_membership, component=coarsened_target, resolve=resolve - ), + classpath_entry_request.for_targets(component=coarsened_target, resolve=resolve), ) _logger.info(f"java compile result = {result}") output_digest = EMPTY_DIGEST diff --git a/src/python/pants/backend/java/goals/check.py b/src/python/pants/backend/java/goals/check.py index 10d62c8fd89..0b7c110fc89 100644 --- a/src/python/pants/backend/java/goals/check.py +++ b/src/python/pants/backend/java/goals/check.py @@ -11,8 +11,12 @@ from pants.engine.addresses import Addresses from pants.engine.rules import Get, MultiGet, collect_rules, rule from pants.engine.target import CoarsenedTargets -from pants.engine.unions import UnionMembership, UnionRule -from pants.jvm.compile import ClasspathEntryRequest, FallibleClasspathEntry +from pants.engine.unions import UnionRule +from pants.jvm.compile import ( + ClasspathEntryRequest, + ClasspathEntryRequestFactory, + FallibleClasspathEntry, +) from pants.jvm.resolve.key import CoursierResolveKey from pants.util.logging import LogLevel @@ -27,7 +31,7 @@ class JavacCheckRequest(CheckRequest): @rule(desc="Check javac compilation", level=LogLevel.DEBUG) async def javac_check( request: JavacCheckRequest, - union_membership: UnionMembership, + classpath_entry_request: ClasspathEntryRequestFactory, ) -> CheckResults: coarsened_targets = await Get( CoarsenedTargets, Addresses(field_set.address for field_set in request.field_sets) @@ -43,7 +47,7 @@ async def javac_check( Get( FallibleClasspathEntry, ClasspathEntryRequest, - ClasspathEntryRequest.for_targets(union_membership, component=target, resolve=resolve), + classpath_entry_request.for_targets(component=target, resolve=resolve), ) for target, resolve in zip(coarsened_targets, resolves) ) diff --git a/src/python/pants/backend/scala/bsp/rules.py b/src/python/pants/backend/scala/bsp/rules.py index 82d1fb5d408..ea93df044c9 100644 --- a/src/python/pants/backend/scala/bsp/rules.py +++ b/src/python/pants/backend/scala/bsp/rules.py @@ -29,7 +29,6 @@ from pants.bsp.util_rules.lifecycle import BSPLanguageSupport from pants.bsp.util_rules.targets import BSPBuildTargets, BSPBuildTargetsRequest from pants.build_graph.address import Address, AddressInput -from pants.core.util_rules.system_binaries import BashBinary, UnzipBinary from pants.engine.addresses import Addresses from pants.engine.fs import EMPTY_DIGEST, AddPrefix, CreateDigest, Digest, DigestEntries from pants.engine.internals.selectors import Get, MultiGet @@ -42,7 +41,11 @@ WrappedTarget, ) from pants.engine.unions import UnionMembership, UnionRule -from pants.jvm.compile import ClasspathEntryRequest, FallibleClasspathEntry +from pants.jvm.compile import ( + ClasspathEntryRequest, + ClasspathEntryRequestFactory, + FallibleClasspathEntry, +) from pants.jvm.resolve.key import CoursierResolveKey from pants.jvm.subsystems import JvmSubsystem from pants.jvm.target_types import JvmResolveField @@ -199,9 +202,7 @@ class ScalaBSPCompileFieldSet(BSPCompileFieldSet): @rule async def bsp_scala_compile_request( request: ScalaBSPCompileFieldSet, - union_membership: UnionMembership, - unzip: UnzipBinary, - bash: BashBinary, + classpath_entry_request: ClasspathEntryRequestFactory, ) -> BSPCompileResult: coarsened_targets = await Get(CoarsenedTargets, Addresses([request.source.address])) assert len(coarsened_targets) == 1 @@ -211,9 +212,7 @@ async def bsp_scala_compile_request( result = await Get( FallibleClasspathEntry, ClasspathEntryRequest, - ClasspathEntryRequest.for_targets( - union_membership, component=coarsened_target, resolve=resolve - ), + classpath_entry_request.for_targets(component=coarsened_target, resolve=resolve), ) _logger.info(f"scala compile result = {result}") output_digest = EMPTY_DIGEST diff --git a/src/python/pants/backend/scala/goals/check.py b/src/python/pants/backend/scala/goals/check.py index 8bb26f72dac..7b011c836d8 100644 --- a/src/python/pants/backend/scala/goals/check.py +++ b/src/python/pants/backend/scala/goals/check.py @@ -11,8 +11,12 @@ from pants.engine.addresses import Addresses from pants.engine.rules import Get, MultiGet, collect_rules, rule from pants.engine.target import CoarsenedTargets -from pants.engine.unions import UnionMembership, UnionRule -from pants.jvm.compile import ClasspathEntryRequest, FallibleClasspathEntry +from pants.engine.unions import UnionRule +from pants.jvm.compile import ( + ClasspathEntryRequest, + ClasspathEntryRequestFactory, + FallibleClasspathEntry, +) from pants.jvm.resolve.key import CoursierResolveKey from pants.util.logging import LogLevel @@ -27,7 +31,7 @@ class ScalacCheckRequest(CheckRequest): @rule(desc="Check compilation for Scala", level=LogLevel.DEBUG) async def scalac_check( request: ScalacCheckRequest, - union_membership: UnionMembership, + classpath_entry_request: ClasspathEntryRequestFactory, ) -> CheckResults: coarsened_targets = await Get( CoarsenedTargets, Addresses(field_set.address for field_set in request.field_sets) @@ -43,7 +47,7 @@ async def scalac_check( Get( FallibleClasspathEntry, ClasspathEntryRequest, - ClasspathEntryRequest.for_targets(union_membership, component=target, resolve=resolve), + classpath_entry_request.for_targets(component=target, resolve=resolve), ) for target, resolve in zip(coarsened_targets, resolves) ) diff --git a/src/python/pants/jvm/classpath.py b/src/python/pants/jvm/classpath.py index 1de2d9f5f8d..2e507e18327 100644 --- a/src/python/pants/jvm/classpath.py +++ b/src/python/pants/jvm/classpath.py @@ -10,8 +10,7 @@ from pants.engine.fs import Digest from pants.engine.rules import Get, MultiGet, collect_rules, rule from pants.engine.target import CoarsenedTargets -from pants.engine.unions import UnionMembership -from pants.jvm.compile import ClasspathEntry, ClasspathEntryRequest +from pants.jvm.compile import ClasspathEntry, ClasspathEntryRequest, ClasspathEntryRequestFactory from pants.jvm.resolve.key import CoursierResolveKey logger = logging.getLogger(__name__) @@ -69,7 +68,7 @@ def root_immutable_inputs_args(self, *, prefix: str = "") -> Iterator[str]: @rule async def classpath( coarsened_targets: CoarsenedTargets, - union_membership: UnionMembership, + classpath_entry_request: ClasspathEntryRequestFactory, ) -> Classpath: # Compute a single shared resolve for all of the roots, which will validate that they # are compatible with one another. @@ -80,9 +79,7 @@ async def classpath( Get( ClasspathEntry, ClasspathEntryRequest, - ClasspathEntryRequest.for_targets( - union_membership, component=t, resolve=resolve, root=True - ), + classpath_entry_request.for_targets(component=t, resolve=resolve, root=True), ) for t in coarsened_targets ) diff --git a/src/python/pants/jvm/compile.py b/src/python/pants/jvm/compile.py index 7fda142d362..ddf238523da 100644 --- a/src/python/pants/jvm/compile.py +++ b/src/python/pants/jvm/compile.py @@ -6,7 +6,7 @@ import logging import os from abc import ABCMeta -from collections import deque +from collections import defaultdict, deque from dataclasses import dataclass from enum import Enum, auto from typing import ClassVar, Iterable, Iterator, Sequence @@ -17,9 +17,16 @@ from pants.engine.internals.selectors import Get, MultiGet from pants.engine.process import FallibleProcessResult from pants.engine.rules import collect_rules, rule -from pants.engine.target import CoarsenedTarget, FieldSet +from pants.engine.target import ( + CoarsenedTarget, + Field, + FieldSet, + GenerateSourcesRequest, + SourcesField, +) from pants.engine.unions import UnionMembership, union from pants.jvm.resolve.key import CoursierResolveKey +from pants.util.frozendict import FrozenDict from pants.util.logging import LogLevel from pants.util.meta import frozen_after_init from pants.util.ordered_set import FrozenOrderedSet @@ -75,9 +82,14 @@ class ClasspathEntryRequest(metaclass=ABCMeta): # True if this request type is only valid at the root of a compile graph. root_only: ClassVar[bool] = False - @staticmethod + +@dataclass(frozen=True) +class ClasspathEntryRequestFactory: + impls: tuple[type[ClasspathEntryRequest], ...] + generator_sources: FrozenDict[type[ClasspathEntryRequest], frozenset[type[SourcesField]]] + def for_targets( - union_membership: UnionMembership, + self, component: CoarsenedTarget, resolve: CoursierResolveKey, *, @@ -92,9 +104,9 @@ def for_targets( compatible = [] partial = [] consume_only = [] - impls = union_membership.get(ClasspathEntryRequest) + impls = self.impls for impl in impls: - classification = ClasspathEntryRequest.classify_impl(impl, component) + classification = self.classify_impl(impl, component) if classification == _ClasspathEntryRequestClassification.INCOMPATIBLE: continue elif classification == _ClasspathEntryRequestClassification.COMPATIBLE: @@ -134,12 +146,16 @@ def for_targets( f"combination of inputs:\n{component.bullet_list()}" ) - @staticmethod def classify_impl( - impl: type[ClasspathEntryRequest], component: CoarsenedTarget + self, impl: type[ClasspathEntryRequest], component: CoarsenedTarget ) -> _ClasspathEntryRequestClassification: targets = component.members - compatible = sum(1 for t in targets for fs in impl.field_sets if fs.is_applicable(t)) + generator_sources = self.generator_sources.get(impl) or frozenset() + + compatible_direct = sum(1 for t in targets for fs in impl.field_sets if fs.is_applicable(t)) + compatible_generated = sum(1 for t in targets for g in generator_sources if t.has_field(g)) + + compatible = compatible_direct + compatible_generated if compatible == 0: return _ClasspathEntryRequestClassification.INCOMPATIBLE if compatible == len(targets): @@ -152,6 +168,31 @@ def classify_impl( return _ClasspathEntryRequestClassification.PARTIAL +@rule +def calculate_jvm_request_types(union_membership: UnionMembership) -> ClasspathEntryRequestFactory: + cpe_impls = union_membership.get(ClasspathEntryRequest) + + impls_by_source: dict[type[Field], type[ClasspathEntryRequest]] = {} + for impl in cpe_impls: + for field_set in impl.field_sets: + for field in field_set.required_fields: + # Assume only one impl per field (normally sound) + # (note that subsequently, we only check for `SourceFields`, so no need to filter) + impls_by_source[field] = impl + + # Classify code generator sources by their CPE impl + sources_by_impl_: dict[type[ClasspathEntryRequest], list[type[SourcesField]]] = defaultdict( + list + ) + + for g in union_membership.get(GenerateSourcesRequest): + if g.output in impls_by_source: + sources_by_impl_[impls_by_source[g.output]].append(g.input) + sources_by_impl = FrozenDict((key, frozenset(value)) for key, value in sources_by_impl_.items()) + + return ClasspathEntryRequestFactory(tuple(cpe_impls), sources_by_impl) + + @frozen_after_init @dataclass(unsafe_hash=True) class ClasspathEntry: @@ -341,7 +382,7 @@ def required_classfiles(fallible_result: FallibleClasspathEntry) -> ClasspathEnt @rule def classpath_dependency_requests( - union_membership: UnionMembership, request: ClasspathDependenciesRequest + classpath_entry_request: ClasspathEntryRequestFactory, request: ClasspathDependenciesRequest ) -> ClasspathEntryRequests: def ignore_because_generated(coarsened_dep: CoarsenedTarget) -> bool: if len(coarsened_dep.members) == 1: @@ -351,8 +392,8 @@ def ignore_because_generated(coarsened_dep: CoarsenedTarget) -> bool: return us.spec_path == them.spec_path and us.target_name == them.target_name return ClasspathEntryRequests( - ClasspathEntryRequest.for_targets( - union_membership, component=coarsened_dep, resolve=request.request.resolve + classpath_entry_request.for_targets( + component=coarsened_dep, resolve=request.request.resolve ) for coarsened_dep in request.request.component.dependencies if not request.ignore_generated or not ignore_because_generated(coarsened_dep) diff --git a/src/python/pants/jvm/compile_test.py b/src/python/pants/jvm/compile_test.py index 25902069add..66b9b9b3296 100644 --- a/src/python/pants/jvm/compile_test.py +++ b/src/python/pants/jvm/compile_test.py @@ -13,11 +13,15 @@ import textwrap from textwrap import dedent -from typing import Sequence, cast +from typing import Sequence, Type, cast import chevron import pytest +from pants.backend.codegen.protobuf.java.rules import GenerateJavaFromProtobufRequest +from pants.backend.codegen.protobuf.java.rules import rules as protobuf_rules +from pants.backend.codegen.protobuf.target_types import ProtobufSourceField, ProtobufSourceTarget +from pants.backend.codegen.protobuf.target_types import rules as protobuf_target_types_rules from pants.backend.java.compile.javac import CompileJavaSourceRequest from pants.backend.java.compile.javac import rules as javac_rules from pants.backend.java.dependency_inference.rules import rules as java_dep_inf_rules @@ -33,17 +37,25 @@ from pants.backend.scala.target_types import ScalaSourcesGeneratorTarget from pants.backend.scala.target_types import rules as scala_target_types_rules from pants.build_graph.address import Address -from pants.core.util_rules import config_files, source_files +from pants.core.util_rules import config_files, source_files, stripped_source_files from pants.core.util_rules.external_tool import rules as external_tool_rules from pants.engine.addresses import Addresses from pants.engine.fs import EMPTY_DIGEST from pants.engine.internals.native_engine import FileDigest -from pants.engine.target import CoarsenedTarget, Target, UnexpandedTargets -from pants.engine.unions import UnionMembership +from pants.engine.target import ( + CoarsenedTarget, + GeneratedSources, + HydratedSources, + HydrateSourcesRequest, + SourcesField, + Target, + UnexpandedTargets, +) from pants.jvm import classpath, jdk_rules, testutil from pants.jvm.classpath import Classpath from pants.jvm.compile import ( ClasspathEntryRequest, + ClasspathEntryRequestFactory, ClasspathSourceAmbiguity, ClasspathSourceMissing, ) @@ -66,6 +78,7 @@ ) from pants.jvm.util_rules import rules as util_rules from pants.testutil.rule_runner import PYTHON_BOOTSTRAP_ENV, QueryRule, RuleRunner +from pants.util.frozendict import FrozenDict DEFAULT_LOCKFILE = TestCoursierWrapper( CoursierResolvedLockfile( @@ -125,11 +138,21 @@ def rule_runner() -> RuleRunner: *java_target_types_rules(), *util_rules(), *testutil.rules(), + *protobuf_rules(), + *stripped_source_files.rules(), + *protobuf_target_types_rules(), QueryRule(Classpath, (Addresses,)), QueryRule(RenderedClasspath, (Addresses,)), QueryRule(UnexpandedTargets, (Addresses,)), + QueryRule(HydratedSources, [HydrateSourcesRequest]), + QueryRule(GeneratedSources, [GenerateJavaFromProtobufRequest]), + ], + target_types=[ + JavaSourcesGeneratorTarget, + JvmArtifactTarget, + ProtobufSourceTarget, + ScalaSourcesGeneratorTarget, ], - target_types=[ScalaSourcesGeneratorTarget, JavaSourcesGeneratorTarget, JvmArtifactTarget], ) rule_runner.set_options(args=[], env_inherit=PYTHON_BOOTSTRAP_ENV) return rule_runner @@ -182,6 +205,22 @@ def main(args: Array[String]): Unit = { ) +def proto_source() -> str: + return dedent( + """\ + syntax = "proto3"; + + package dir1; + + message Person { + string name = 1; + int32 id = 2; + string email = 3; + } + """ + ) + + class CompileMockSourceRequest(ClasspathEntryRequest): field_sets = (JavaFieldSet, JavaGeneratorFieldSet) @@ -191,9 +230,12 @@ def test_request_classification(rule_runner: RuleRunner) -> None: def classify( targets: Sequence[Target], members: Sequence[type[ClasspathEntryRequest]], + generators: FrozenDict[type[ClasspathEntryRequest], frozenset[type[SourcesField]]], ) -> tuple[type[ClasspathEntryRequest], type[ClasspathEntryRequest] | None]: - req = ClasspathEntryRequest.for_targets( - UnionMembership({ClasspathEntryRequest: members}), + + factory = ClasspathEntryRequestFactory(tuple(members), generators) + + req = factory.for_targets( CoarsenedTarget(targets, ()), CoursierResolveKey("example", "path", EMPTY_DIGEST), ) @@ -206,13 +248,15 @@ def classify( scala_sources(name='scala') java_sources(name='java') jvm_artifact(name='jvm_artifact', group='ex', artifact='ex', version='0.0.0') + protobuf_source(name='proto', source="f.proto") {DEFAULT_SCALA_LIBRARY_TARGET} """ ), + "f.proto": proto_source(), "3rdparty/jvm/default.lock": DEFAULT_LOCKFILE, } ) - scala, java, jvm_artifact = rule_runner.request( + scala, java, jvm_artifact, proto = rule_runner.request( UnexpandedTargets, [ Addresses( @@ -220,33 +264,41 @@ def classify( Address("", target_name="scala"), Address("", target_name="java"), Address("", target_name="jvm_artifact"), + Address("", target_name="proto"), ] ) ], ) all_members = [CompileJavaSourceRequest, CompileScalaSourceRequest, CoursierFetchRequest] + generators = FrozenDict( + { + CompileJavaSourceRequest: frozenset([cast(Type[SourcesField], ProtobufSourceField)]), + CompileScalaSourceRequest: frozenset(), + } + ) # Fully compatible. - assert (CompileJavaSourceRequest, None) == classify([java], all_members) - assert (CompileScalaSourceRequest, None) == classify([scala], all_members) - assert (CoursierFetchRequest, None) == classify([jvm_artifact], all_members) + assert (CompileJavaSourceRequest, None) == classify([java], all_members, generators) + assert (CompileScalaSourceRequest, None) == classify([scala], all_members, generators) + assert (CoursierFetchRequest, None) == classify([jvm_artifact], all_members, generators) + assert (CompileJavaSourceRequest, None) == classify([proto], all_members, generators) # Partially compatible. assert (CompileJavaSourceRequest, CompileScalaSourceRequest) == classify( - [java, scala], all_members + [java, scala], all_members, generators ) with pytest.raises(ClasspathSourceMissing): - classify([java, jvm_artifact], all_members) + classify([java, jvm_artifact], all_members, generators) # None compatible. with pytest.raises(ClasspathSourceMissing): - classify([java], []) + classify([java], [], generators) with pytest.raises(ClasspathSourceMissing): - classify([scala, java, jvm_artifact], all_members) + classify([scala, java, jvm_artifact], all_members, generators) # Too many compatible. with pytest.raises(ClasspathSourceAmbiguity): - classify([java], [CompileJavaSourceRequest, CompileMockSourceRequest]) + classify([java], [CompileJavaSourceRequest, CompileMockSourceRequest], generators) @maybe_skip_jdk_test