From e628a930e62853203f5729b536278749ce7b4ba4 Mon Sep 17 00:00:00 2001 From: Stu Hood Date: Mon, 18 Jul 2022 16:58:27 -0700 Subject: [PATCH] Generically compute dynamic defaults for `Field`s (#16206) As discussed on #16175, we don't currently consume the "dynamic" defaults of field values for the purposes of `parametrize`. That is at least partially because there is no generic way to do so: a `Field` has no way to declare a dynamic default currently, because `Field`s cannot declare a dependency `@rule_helper` to compute their value (...yet? see https://github.com/pantsbuild/pants/issues/12934#issuecomment-1111608974). This change adds a mechanism for generically declaring the default value of a `Field`. This is definitely not the most ergonomic API: over the next few versions, many dynamic `Field` defaults will hopefully move to `__defaults__`. And https://github.com/pantsbuild/pants/issues/12934#issuecomment-1111608974 will hopefully allow for significantly cleaning up those that remain. Fixes #16175. [ci skip-rust] [ci skip-build-wheels] --- src/python/pants/backend/java/bsp/rules.py | 19 +-- src/python/pants/backend/java/target_types.py | 6 +- .../backend/python/target_types_rules.py | 20 +++ src/python/pants/backend/scala/bsp/rules.py | 20 +-- .../pants/backend/scala/target_types.py | 2 + src/python/pants/bsp/util_rules/targets.py | 56 +++------ .../pants/bsp/util_rules/targets_test.py | 2 + src/python/pants/engine/internals/graph.py | 25 +++- .../pants/engine/internals/graph_test.py | 34 ++++-- .../pants/engine/internals/parametrize.py | 28 +++-- .../engine/internals/parametrize_test.py | 114 ++++++++++-------- src/python/pants/engine/target.py | 66 +++++++++- src/python/pants/jvm/target_types.py | 28 +++++ 13 files changed, 274 insertions(+), 146 deletions(-) diff --git a/src/python/pants/backend/java/bsp/rules.py b/src/python/pants/backend/java/bsp/rules.py index 7f1f99d887e..330bdc6e538 100644 --- a/src/python/pants/backend/java/bsp/rules.py +++ b/src/python/pants/backend/java/bsp/rules.py @@ -15,8 +15,6 @@ from pants.bsp.util_rules.targets import ( BSPBuildTargetsMetadataRequest, BSPBuildTargetsMetadataResult, - BSPResolveFieldFactoryRequest, - BSPResolveFieldFactoryResult, ) from pants.engine.addresses import Addresses from pants.engine.fs import CreateDigest, DigestEntries @@ -31,7 +29,6 @@ FallibleClasspathEntry, ) from pants.jvm.resolve.key import CoursierResolveKey -from pants.jvm.subsystems import JvmSubsystem from pants.jvm.target_types import JvmResolveField LANGUAGE_ID = "java" @@ -52,28 +49,15 @@ class JavaMetadataFieldSet(FieldSet): resolve: JvmResolveField -class JavaBSPResolveFieldFactoryRequest(BSPResolveFieldFactoryRequest): - resolve_prefix = "jvm" - - class JavaBSPBuildTargetsMetadataRequest(BSPBuildTargetsMetadataRequest): language_id = LANGUAGE_ID can_merge_metadata_from = () field_set_type = JavaMetadataFieldSet + resolve_prefix = "jvm" resolve_field = JvmResolveField -@rule -def bsp_resolve_field_factory( - request: JavaBSPResolveFieldFactoryRequest, - jvm: JvmSubsystem, -) -> BSPResolveFieldFactoryResult: - return BSPResolveFieldFactoryResult( - lambda target: target.get(JvmResolveField).normalized_value(jvm) - ) - - @rule async def bsp_resolve_java_metadata( _: JavaBSPBuildTargetsMetadataRequest, @@ -195,7 +179,6 @@ def rules(): return ( *collect_rules(), UnionRule(BSPLanguageSupport, JavaBSPLanguageSupport), - UnionRule(BSPResolveFieldFactoryRequest, JavaBSPResolveFieldFactoryRequest), UnionRule(BSPBuildTargetsMetadataRequest, JavaBSPBuildTargetsMetadataRequest), UnionRule(BSPHandlerMapping, JavacOptionsHandlerMapping), UnionRule(BSPCompileRequest, JavaBSPCompileRequest), diff --git a/src/python/pants/backend/java/target_types.py b/src/python/pants/backend/java/target_types.py index 5495ace3222..29024249748 100644 --- a/src/python/pants/backend/java/target_types.py +++ b/src/python/pants/backend/java/target_types.py @@ -15,6 +15,7 @@ Target, TargetFilesGenerator, ) +from pants.jvm import target_types as jvm_target_types from pants.jvm.target_types import ( JunitTestSourceField, JvmJdkField, @@ -128,4 +129,7 @@ class JavaSourcesGeneratorTarget(TargetFilesGenerator): def rules(): - return collect_rules() + return [ + *collect_rules(), + *jvm_target_types.rules(), + ] diff --git a/src/python/pants/backend/python/target_types_rules.py b/src/python/pants/backend/python/target_types_rules.py index ef9c3c685af..ec953ff8c01 100644 --- a/src/python/pants/backend/python/target_types_rules.py +++ b/src/python/pants/backend/python/target_types_rules.py @@ -51,6 +51,8 @@ Dependencies, DependenciesRequest, ExplicitlyProvidedDependencies, + FieldDefaultFactoryRequest, + FieldDefaultFactoryResult, FieldSet, GeneratedTargets, GenerateTargetsRequest, @@ -441,6 +443,23 @@ async def inject_python_distribution_dependencies( ) +# ----------------------------------------------------------------------------------------------- +# Dynamic Field defaults +# ----------------------------------------------------------------------------------------------- + + +class PythonResolveFieldDefaultFactoryRequest(FieldDefaultFactoryRequest): + field_type = PythonResolveField + + +@rule +def python_resolve_field_default_factory( + request: PythonResolveFieldDefaultFactoryRequest, + python_setup: PythonSetup, +) -> FieldDefaultFactoryResult: + return FieldDefaultFactoryResult(lambda f: f.normalized_value(python_setup)) + + # ----------------------------------------------------------------------------------------------- # Dependency validation # ----------------------------------------------------------------------------------------------- @@ -509,6 +528,7 @@ def rules(): return ( *collect_rules(), *import_rules(), + UnionRule(FieldDefaultFactoryRequest, PythonResolveFieldDefaultFactoryRequest), UnionRule(TargetFilesGeneratorSettingsRequest, PythonFilesGeneratorSettingsRequest), UnionRule(GenerateTargetsRequest, GenerateTargetsFromPexBinaries), UnionRule(InjectDependenciesRequest, InjectPexBinaryEntryPointDependency), diff --git a/src/python/pants/backend/scala/bsp/rules.py b/src/python/pants/backend/scala/bsp/rules.py index ab5025f7748..1a67250cfb0 100644 --- a/src/python/pants/backend/scala/bsp/rules.py +++ b/src/python/pants/backend/scala/bsp/rules.py @@ -31,8 +31,6 @@ BSPBuildTargetsMetadataResult, BSPDependencyModulesRequest, BSPDependencyModulesResult, - BSPResolveFieldFactoryRequest, - BSPResolveFieldFactoryResult, ) from pants.engine.addresses import Addresses from pants.engine.fs import ( @@ -91,15 +89,14 @@ class ScalaMetadataFieldSet(FieldSet): resolve: JvmResolveField -class ScalaBSPResolveFieldFactoryRequest(BSPResolveFieldFactoryRequest): - resolve_prefix = "jvm" - - class ScalaBSPBuildTargetsMetadataRequest(BSPBuildTargetsMetadataRequest): language_id = LANGUAGE_ID can_merge_metadata_from = ("java",) field_set_type = ScalaMetadataFieldSet + resolve_prefix = "jvm" + resolve_field = JvmResolveField + @dataclass(frozen=True) class ResolveScalaBSPBuildTargetRequest: @@ -154,16 +151,6 @@ async def materialize_scala_runtime_jars( return MaterializeScalaRuntimeJarsResult(materialized_classpath) -@rule -def bsp_resolve_field_factory( - request: ScalaBSPResolveFieldFactoryRequest, - jvm: JvmSubsystem, -) -> BSPResolveFieldFactoryResult: - return BSPResolveFieldFactoryResult( - lambda target: target.get(JvmResolveField).normalized_value(jvm) - ) - - @rule async def bsp_resolve_scala_metadata( request: ScalaBSPBuildTargetsMetadataRequest, @@ -470,7 +457,6 @@ def rules(): *collect_rules(), UnionRule(BSPLanguageSupport, ScalaBSPLanguageSupport), UnionRule(BSPBuildTargetsMetadataRequest, ScalaBSPBuildTargetsMetadataRequest), - UnionRule(BSPResolveFieldFactoryRequest, ScalaBSPResolveFieldFactoryRequest), UnionRule(BSPHandlerMapping, ScalacOptionsHandlerMapping), UnionRule(BSPHandlerMapping, ScalaMainClassesHandlerMapping), UnionRule(BSPHandlerMapping, ScalaTestClassesHandlerMapping), diff --git a/src/python/pants/backend/scala/target_types.py b/src/python/pants/backend/scala/target_types.py index 7f793ec7367..7bd4f1a279d 100644 --- a/src/python/pants/backend/scala/target_types.py +++ b/src/python/pants/backend/scala/target_types.py @@ -20,6 +20,7 @@ TargetFilesGeneratorSettingsRequest, ) from pants.engine.unions import UnionRule +from pants.jvm import target_types as jvm_target_types from pants.jvm.target_types import ( JunitTestSourceField, JvmJdkField, @@ -276,5 +277,6 @@ class ScalacPluginTarget(Target): def rules(): return ( *collect_rules(), + *jvm_target_types.rules(), UnionRule(TargetFilesGeneratorSettingsRequest, ScalaSettingsRequest), ) diff --git a/src/python/pants/bsp/util_rules/targets.py b/src/python/pants/bsp/util_rules/targets.py index c5282fcd072..1d601360524 100644 --- a/src/python/pants/bsp/util_rules/targets.py +++ b/src/python/pants/bsp/util_rules/targets.py @@ -10,7 +10,6 @@ from typing import ClassVar, Generic, Sequence, Type, TypeVar import toml -from typing_extensions import Protocol from pants.base.build_root import BuildRoot from pants.base.glob_match_error_behavior import GlobMatchErrorBehavior @@ -40,11 +39,12 @@ from pants.engine.internals.selectors import Get, MultiGet from pants.engine.rules import _uncacheable_rule, collect_rules, rule from pants.engine.target import ( + Field, + FieldDefaults, FieldSet, SourcesField, SourcesPaths, SourcesPathsRequest, - Target, Targets, ) from pants.engine.unions import UnionMembership, UnionRule, union @@ -58,38 +58,6 @@ _FS = TypeVar("_FS", bound=FieldSet) -@union -@dataclass(frozen=True) -class BSPResolveFieldFactoryRequest(Generic[_FS]): - """Requests an implementation of `BSPResolveFieldFactory` which can filter resolve fields. - - TODO: This is to work around the fact that Field value defaulting cannot have arbitrary - subsystem requirements, and so `JvmResolveField` and `PythonResolveField` have methods - which compute the true value of the field given a subsytem argument. Consumers need to - be type aware, and `@rules` cannot have dynamic requirements. - - See https://github.com/pantsbuild/pants/issues/12934 about potentially allowing unions - (including Field registrations) to have `@rule_helper` methods, which would allow the - computation of an AsyncFields to directly require a subsystem. - """ - - resolve_prefix: ClassVar[str] - - -# TODO: Workaround for https://github.com/python/mypy/issues/5485, because we cannot directly use -# a Callable. -class _ResolveFieldFactory(Protocol): - def __call__(self, target: Target) -> str | None: - pass - - -@dataclass(frozen=True) -class BSPResolveFieldFactoryResult: - """Computes the resolve field value for a Target, if applicable.""" - - resolve_field_value: _ResolveFieldFactory - - @union @dataclass(frozen=True) class BSPBuildTargetsMetadataRequest(Generic[_FS]): @@ -99,6 +67,9 @@ class BSPBuildTargetsMetadataRequest(Generic[_FS]): can_merge_metadata_from: ClassVar[tuple[str, ...]] field_set_type: ClassVar[Type[_FS]] + resolve_prefix: ClassVar[str] + resolve_field: ClassVar[type[Field]] + field_sets: tuple[_FS, ...] @@ -255,6 +226,7 @@ async def resolve_bsp_build_target_identifier( async def resolve_bsp_build_target_addresses( bsp_target: BSPBuildTargetInternal, union_membership: UnionMembership, + field_defaults: FieldDefaults, ) -> Targets: targets = await Get(Targets, AddressSpecs, bsp_target.specs.address_specs) if bsp_target.definition.resolve_filter is None: @@ -268,17 +240,19 @@ async def resolve_bsp_build_target_addresses( f"prefix like `$lang:$filter`, but the configured value: `{resolve_filter}` did not." ) - # TODO: See `BSPResolveFieldFactoryRequest` re: this awkwardness. - factories = await MultiGet( - Get(BSPResolveFieldFactoryResult, BSPResolveFieldFactoryRequest, request()) - for request in union_membership.get(BSPResolveFieldFactoryRequest) - if request.resolve_prefix == resolve_prefix - ) + resolve_fields = { + impl.resolve_field + for impl in union_membership.get(BSPBuildTargetsMetadataRequest) + if impl.resolve_prefix == resolve_prefix + } return Targets( t for t in targets - if any((factory.resolve_field_value)(t) == resolve_value for factory in factories) + if any( + t.has_field(field) and field_defaults.value_or_default(t[field]) == resolve_value + for field in resolve_fields + ) ) diff --git a/src/python/pants/bsp/util_rules/targets_test.py b/src/python/pants/bsp/util_rules/targets_test.py index e3fe530f192..6869c6be8b2 100644 --- a/src/python/pants/bsp/util_rules/targets_test.py +++ b/src/python/pants/bsp/util_rules/targets_test.py @@ -4,6 +4,7 @@ import pytest +from pants.backend.java import target_types from pants.backend.java.bsp import rules as java_bsp_rules from pants.backend.java.compile import javac from pants.backend.java.target_types import JavaSourceTarget @@ -29,6 +30,7 @@ def rule_runner() -> RuleRunner: *jvm_tool.rules(), *jvm_util_rules.rules(), *jdk_rules.rules(), + *target_types.rules(), QueryRule(BSPBuildTargets, ()), QueryRule(Targets, [BuildTargetIdentifier]), ], diff --git a/src/python/pants/engine/internals/graph.py b/src/python/pants/engine/internals/graph.py index 7999517f520..f0dcf9c08f9 100644 --- a/src/python/pants/engine/internals/graph.py +++ b/src/python/pants/engine/internals/graph.py @@ -57,6 +57,9 @@ DependenciesRequest, ExplicitlyProvidedDependencies, Field, + FieldDefaultFactoryRequest, + FieldDefaultFactoryResult, + FieldDefaults, FieldSet, FieldSetsPerTarget, FieldSetsPerTargetRequest, @@ -1066,6 +1069,7 @@ async def resolve_dependencies( target_types_to_generate_requests: TargetTypesToGenerateTargetsRequests, union_membership: UnionMembership, subproject_roots: SubprojectRoots, + field_defaults: FieldDefaults, ) -> Addresses: wrapped_tgt, explicitly_provided = await MultiGet( Get(WrappedTarget, Address, request.field.address), @@ -1119,7 +1123,7 @@ async def resolve_dependencies( ) explicitly_provided_includes = [ - parametrizations.get_subset(address, tgt).address + parametrizations.get_subset(address, tgt, field_defaults).address for address, parametrizations in zip( explicitly_provided_includes, explicit_dependency_parametrizations ) @@ -1200,6 +1204,25 @@ async def resolve_unparsed_address_inputs( return Addresses(addresses) +# ----------------------------------------------------------------------------------------------- +# Dynamic Field defaults +# ----------------------------------------------------------------------------------------------- + + +@rule +async def field_defaults(union_membership: UnionMembership) -> FieldDefaults: + requests = list(union_membership.get(FieldDefaultFactoryRequest)) + factories = await MultiGet( + Get(FieldDefaultFactoryResult, FieldDefaultFactoryRequest, impl()) for impl in requests + ) + return FieldDefaults( + FrozenDict( + (request.field_type, factory.default_factory) + for request, factory in zip(requests, factories) + ) + ) + + # ----------------------------------------------------------------------------------------------- # Find applicable field sets # ----------------------------------------------------------------------------------------------- diff --git a/src/python/pants/engine/internals/graph_test.py b/src/python/pants/engine/internals/graph_test.py index a2d97a9b3e0..d907658f141 100644 --- a/src/python/pants/engine/internals/graph_test.py +++ b/src/python/pants/engine/internals/graph_test.py @@ -58,6 +58,8 @@ DependenciesRequest, ExplicitlyProvidedDependencies, Field, + FieldDefaultFactoryRequest, + FieldDefaultFactoryResult, FieldSet, FilteredTargets, GeneratedSources, @@ -106,6 +108,25 @@ class SpecialCasedDeps2(SpecialCasedDependencies): alias = "special_cased_deps2" +class ResolveField(StringField, AsyncFieldMixin): + alias = "resolve" + default = None + + +_DEFAULT_RESOLVE = "default_test_resolve" + + +class ResolveFieldDefaultFactoryRequest(FieldDefaultFactoryRequest): + field_type = ResolveField + + +@rule +def resolve_field_default_factory( + request: ResolveFieldDefaultFactoryRequest, +) -> FieldDefaultFactoryResult: + return FieldDefaultFactoryResult(lambda f: f.value or _DEFAULT_RESOLVE) + + class MockTarget(Target): alias = "target" core_fields = ( @@ -119,10 +140,6 @@ class MockTarget(Target): deprecated_alias_removal_version = "9.9.9.dev0" -class ResolveField(StringField, AsyncFieldMixin): - alias = "resolve" - - class MockGeneratedTarget(Target): alias = "generated" core_fields = (MockDependencies, Tags, SingleSourceField, ResolveField) @@ -999,6 +1016,8 @@ def generated_targets_rule_runner() -> RuleRunner: QueryRule(Addresses, [Specs]), QueryRule(_DependencyMapping, [_DependencyMappingRequest]), QueryRule(_TargetParametrizations, [Address]), + UnionRule(FieldDefaultFactoryRequest, ResolveFieldDefaultFactoryRequest), + resolve_field_default_factory, ], target_types=[MockTargetGenerator, MockGeneratedTarget], objects={"parametrize": Parametrize}, @@ -1276,12 +1295,11 @@ def test_parametrize_partial_atom_to_atom(generated_targets_rule_runner: RuleRun """\ generated( name='t1', - resolve=parametrize('a', 'b'), + resolve=parametrize('default_test_resolve', 'b'), source='f1.ext', ) generated( name='t2', - resolve='a', source='f2.ext', dependencies=[':t1'], ) @@ -1289,9 +1307,9 @@ def test_parametrize_partial_atom_to_atom(generated_targets_rule_runner: RuleRun ), ["f1.ext", "f2.ext"], expected_dependencies={ - "demo:t1@resolve=a": set(), + "demo:t1@resolve=default_test_resolve": set(), "demo:t1@resolve=b": set(), - "demo:t2": {"demo:t1@resolve=a"}, + "demo:t2": {"demo:t1@resolve=default_test_resolve"}, }, ) diff --git a/src/python/pants/engine/internals/parametrize.py b/src/python/pants/engine/internals/parametrize.py index 0e666045e82..67ddaa1a297 100644 --- a/src/python/pants/engine/internals/parametrize.py +++ b/src/python/pants/engine/internals/parametrize.py @@ -10,7 +10,7 @@ from pants.build_graph.address import BANNED_CHARS_IN_PARAMETERS from pants.engine.addresses import Address from pants.engine.collection import Collection -from pants.engine.target import Field, Target +from pants.engine.target import Field, FieldDefaults, Target from pants.util.frozendict import FrozenDict from pants.util.meta import frozen_after_init from pants.util.strutil import bullet_list @@ -200,7 +200,9 @@ def get_all_superset_targets(self, address: Address) -> Iterator[Address]: if address.is_parametrized_subset_of(parametrized_tgt.address): yield parametrized_tgt.address - def get_subset(self, address: Address, consumer: Target) -> Target: + def get_subset( + self, address: Address, consumer: Target, field_defaults: FieldDefaults + ) -> Target: """Find the Target with the given Address, or with fields matching the given consumer.""" # Check for exact matches. instance = self.get(address) @@ -214,9 +216,9 @@ def remaining_fields_match(candidate: Target) -> bool: } return all( _concrete_fields_are_equivalent( + field_defaults, consumer=consumer, - candidate_field_type=field_type, - candidate_field_value=field.value, + candidate_field=field, ) for field_type, field in candidate.field_values.items() if field_type.alias in unspecified_param_field_names @@ -277,11 +279,17 @@ def _bare_address_error(self, address) -> ValueError: def _concrete_fields_are_equivalent( - *, consumer: Target, candidate_field_value: Any, candidate_field_type: type[Field] + field_defaults: FieldDefaults, *, consumer: Target, candidate_field: Field ) -> bool: - # TODO(#16175): Does not account for the computed default values of Fields. + candidate_field_type = type(candidate_field) + candidate_field_value = field_defaults.value_or_default(candidate_field) + if consumer.has_field(candidate_field_type): - return cast(bool, consumer[candidate_field_type].value == candidate_field_value) + return cast( + bool, + field_defaults.value_or_default(consumer[candidate_field_type]) + == candidate_field_value, + ) # Else, see if the consumer has a field that is a superclass of `candidate_field_value`, to # handle https://github.com/pantsbuild/pants/issues/16190. This is only safe because we are # confident that both `candidate_field_type` and the fields from `consumer` are _concrete_, @@ -290,10 +298,12 @@ def _concrete_fields_are_equivalent( ( consumer_field for consumer_field in consumer.field_types - if issubclass(candidate_field_type, consumer_field) + if isinstance(candidate_field, consumer_field) ), None, ) if superclass is None: return False - return cast(bool, consumer[superclass].value == candidate_field_value) + return cast( + bool, field_defaults.value_or_default(consumer[superclass]) == candidate_field_value + ) diff --git a/src/python/pants/engine/internals/parametrize_test.py b/src/python/pants/engine/internals/parametrize_test.py index 4c2459d8df2..a4da6d2776e 100644 --- a/src/python/pants/engine/internals/parametrize_test.py +++ b/src/python/pants/engine/internals/parametrize_test.py @@ -15,7 +15,7 @@ _TargetParametrization, _TargetParametrizations, ) -from pants.engine.target import Field, Target +from pants.engine.target import Field, FieldDefaults, Target from pants.util.frozendict import FrozenDict @@ -144,10 +144,12 @@ def assert_gets(addr: Address, expected: set[Address]) -> None: def test_concrete_fields_are_equivalent() -> None: class ParentField(Field): alias = "parent" + default = None help = "foo" class ChildField(ParentField): alias = "child" + default = None help = "foo" class UnrelatedField(Field): @@ -164,66 +166,78 @@ class ChildTarget(Target): help = "foo" core_fields = (ChildField,) + # Validate literal value matches. + empty_defaults = FieldDefaults(FrozenDict()) + unused_addr = Address("unused") parent_tgt = ParentTarget({"parent": "val"}, Address("parent")) - assert ( - _concrete_fields_are_equivalent( - consumer=parent_tgt, candidate_field_type=ParentField, candidate_field_value="val" - ) - is True + child_tgt = ChildTarget({"child": "val"}, Address("child")) + + assert _concrete_fields_are_equivalent( + empty_defaults, consumer=parent_tgt, candidate_field=ParentField("val", unused_addr) ) - assert ( - _concrete_fields_are_equivalent( - consumer=parent_tgt, candidate_field_type=ParentField, candidate_field_value="different" - ) - is False + assert not _concrete_fields_are_equivalent( + empty_defaults, + consumer=parent_tgt, + candidate_field=ParentField("different", unused_addr), ) - assert ( - _concrete_fields_are_equivalent( - consumer=parent_tgt, candidate_field_type=ChildField, candidate_field_value="val" - ) - is True + assert _concrete_fields_are_equivalent( + empty_defaults, consumer=parent_tgt, candidate_field=ChildField("val", unused_addr) ) - assert ( - _concrete_fields_are_equivalent( - consumer=parent_tgt, candidate_field_type=ChildField, candidate_field_value="different" - ) - is False + assert not _concrete_fields_are_equivalent( + empty_defaults, + consumer=parent_tgt, + candidate_field=ChildField("different", unused_addr), ) - assert ( - _concrete_fields_are_equivalent( - consumer=parent_tgt, candidate_field_type=UnrelatedField, candidate_field_value="val" - ) - is False + assert not _concrete_fields_are_equivalent( + empty_defaults, consumer=parent_tgt, candidate_field=UnrelatedField("val", unused_addr) ) - child_tgt = ChildTarget({"child": "val"}, Address("child")) - assert ( - _concrete_fields_are_equivalent( - consumer=child_tgt, candidate_field_type=ParentField, candidate_field_value="val" - ) - is True + assert _concrete_fields_are_equivalent( + empty_defaults, consumer=child_tgt, candidate_field=ParentField("val", unused_addr) ) - assert ( - _concrete_fields_are_equivalent( - consumer=child_tgt, candidate_field_type=ParentField, candidate_field_value="different" - ) - is False + assert not _concrete_fields_are_equivalent( + empty_defaults, + consumer=child_tgt, + candidate_field=ParentField("different", unused_addr), ) - assert ( - _concrete_fields_are_equivalent( - consumer=child_tgt, candidate_field_type=ChildField, candidate_field_value="val" - ) - is True + assert _concrete_fields_are_equivalent( + empty_defaults, consumer=child_tgt, candidate_field=ChildField("val", unused_addr) + ) + assert not _concrete_fields_are_equivalent( + empty_defaults, consumer=child_tgt, candidate_field=ChildField("different", unused_addr) ) - assert ( - _concrete_fields_are_equivalent( - consumer=child_tgt, candidate_field_type=ChildField, candidate_field_value="different" + assert not _concrete_fields_are_equivalent( + empty_defaults, consumer=child_tgt, candidate_field=UnrelatedField("val", unused_addr) + ) + + # Validate field defaulting. + parent_field_defaults = FieldDefaults( + FrozenDict( + { + ParentField: lambda f: f.value or "val", + } ) - is False ) - assert ( - _concrete_fields_are_equivalent( - consumer=child_tgt, candidate_field_type=UnrelatedField, candidate_field_value="val" + child_field_defaults = FieldDefaults( + FrozenDict( + { + ChildField: lambda f: f.value or "val", + } ) - is False + ) + assert _concrete_fields_are_equivalent( + parent_field_defaults, consumer=child_tgt, candidate_field=ParentField(None, unused_addr) + ) + assert _concrete_fields_are_equivalent( + parent_field_defaults, + consumer=ParentTarget({}, Address("parent")), + candidate_field=ChildField("val", unused_addr), + ) + assert _concrete_fields_are_equivalent( + child_field_defaults, consumer=parent_tgt, candidate_field=ChildField(None, unused_addr) + ) + assert _concrete_fields_are_equivalent( + child_field_defaults, + consumer=ChildTarget({}, Address("child")), + candidate_field=ParentField("val", unused_addr), ) diff --git a/src/python/pants/engine/target.py b/src/python/pants/engine/target.py index 0d610999c61..6fd8a50a22c 100644 --- a/src/python/pants/engine/target.py +++ b/src/python/pants/engine/target.py @@ -34,7 +34,7 @@ get_type_hints, ) -from typing_extensions import final +from typing_extensions import Protocol, final from pants.base.deprecated import warn_or_error from pants.engine.addresses import Address, Addresses, UnparsedAddressInputs, assert_single_address @@ -262,10 +262,74 @@ def __eq__(self, other: Union[Any, AsyncFieldMixin]) -> bool: ) +@union +@dataclass(frozen=True) +class FieldDefaultFactoryRequest: + """Registers a dynamic default for a Field. + + See `FieldDefaults`. + """ + + field_type: ClassVar[type[Field]] + + +# TODO: Workaround for https://github.com/python/mypy/issues/5485, because we cannot directly use +# a Callable. +class FieldDefaultFactory(Protocol): + def __call__(self, field: Field) -> Any: + pass + + +@dataclass(frozen=True) +class FieldDefaultFactoryResult: + """A wrapper for a function which computes the default value of a Field.""" + + default_factory: FieldDefaultFactory + + +@dataclass(frozen=True) +class FieldDefaults: + """Generic Field default values. To install a default, see `FieldDefaultFactoryRequest`. + + TODO: This is to work around the fact that Field value defaulting cannot have arbitrary + subsystem requirements, and so e.g. `JvmResolveField` and `PythonResolveField` have methods + which compute the true value of the field given a subsytem argument. Consumers need to + be type aware, and `@rules` cannot have dynamic requirements. + + Additionally, `__defaults__` should mean that computed default Field values should become + more rare: i.e. `JvmResolveField` and `PythonResolveField` could potentially move to + hardcoded default values which users override with `__defaults__` if they'd like to change + the default resolve names. + + See https://github.com/pantsbuild/pants/issues/12934 about potentially allowing unions + (including Field registrations) to have `@rule_helper` methods, which would allow the + computation of an AsyncField to directly require a subsystem. + """ + + _factories: FrozenDict[type[Field], FieldDefaultFactory] + + @memoized_method + def factory(self, field_type: type[Field]) -> FieldDefaultFactory: + """Looks up a Field default factory in a subclass-aware way.""" + factory = self._factories.get(field_type, None) + if factory is not None: + return factory + + for ft, factory in self._factories.items(): + if issubclass(field_type, ft): + return factory + + return lambda f: f.value + + def value_or_default(self, field: Field) -> Any: + return (self.factory(type(field)))(field) + + # ----------------------------------------------------------------------------------------------- # Core Target abstractions # ----------------------------------------------------------------------------------------------- + # NB: This TypeVar is what allows `Target.get()` to properly work with MyPy so that MyPy knows # the precise Field returned. _F = TypeVar("_F", bound=Field) diff --git a/src/python/pants/jvm/target_types.py b/src/python/pants/jvm/target_types.py index 50f042a166f..6a2d549709c 100644 --- a/src/python/pants/jvm/target_types.py +++ b/src/python/pants/jvm/target_types.py @@ -10,10 +10,13 @@ from pants.core.goals.package import OutputPathField from pants.core.goals.run import RestartableField from pants.engine.addresses import Address +from pants.engine.rules import collect_rules, rule from pants.engine.target import ( COMMON_TARGET_FIELDS, AsyncFieldMixin, Dependencies, + FieldDefaultFactoryRequest, + FieldDefaultFactoryResult, FieldSet, InvalidFieldException, InvalidTargetException, @@ -24,6 +27,7 @@ StringSequenceField, Target, ) +from pants.engine.unions import UnionRule from pants.jvm.subsystems import JvmSubsystem from pants.util.docutil import git_url from pants.util.strutil import softwrap @@ -378,3 +382,27 @@ class JvmWarTarget(Target): deploys in Java Servlet containers. """ ) + + +# ----------------------------------------------------------------------------------------------- +# Dynamic Field defaults +# -----------------------------------------------------------------------------------------------# + + +class JvmResolveFieldDefaultFactoryRequest(FieldDefaultFactoryRequest): + field_type = JvmResolveField + + +@rule +def jvm_resolve_field_default_factory( + request: JvmResolveFieldDefaultFactoryRequest, + jvm: JvmSubsystem, +) -> FieldDefaultFactoryResult: + return FieldDefaultFactoryResult(lambda f: f.normalized_value(jvm)) + + +def rules(): + return [ + *collect_rules(), + UnionRule(FieldDefaultFactoryRequest, JvmResolveFieldDefaultFactoryRequest), + ]