diff --git a/src/python/pants/backend/shell/dependency_inference.py b/src/python/pants/backend/shell/dependency_inference.py index e877ca40208..3874ca7a552 100644 --- a/src/python/pants/backend/shell/dependency_inference.py +++ b/src/python/pants/backend/shell/dependency_inference.py @@ -16,7 +16,7 @@ from pants.core.util_rules.external_tool import DownloadedExternalTool, ExternalToolRequest from pants.engine.addresses import Address from pants.engine.collection import DeduplicatedCollection -from pants.engine.fs import Digest, DigestSubset, MergeDigests, PathGlobs +from pants.engine.fs import Digest, MergeDigests from pants.engine.platform import Platform from pants.engine.process import FallibleProcessResult, Process, ProcessCacheScope from pants.engine.rules import Get, MultiGet, collect_rules, rule @@ -68,11 +68,12 @@ async def map_shell_files(tgts: AllShellTargets) -> ShellMapping: files_to_addresses: dict[str, Address] = {} files_with_multiple_owners: DefaultDict[str, set[Address]] = defaultdict(set) for tgt, sources in zip(tgts, sources_per_target): - for f in sources.files: - if f in files_to_addresses: - files_with_multiple_owners[f].update({files_to_addresses[f], tgt.address}) - else: - files_to_addresses[f] = tgt.address + assert len(sources.files) == 1 + fp = sources.files[0] + if fp in files_to_addresses: + files_with_multiple_owners[fp].update({files_to_addresses[fp], tgt.address}) + else: + files_to_addresses[fp] = tgt.address # Remove files with ambiguous owners. for ambiguous_f in files_with_multiple_owners: @@ -92,9 +93,6 @@ class ParsedShellImports(DeduplicatedCollection): @dataclass(frozen=True) class ParseShellImportsRequest: - # NB: We parse per-file, rather than per-target. This is necessary so that we can have each - # file in complete isolation without its sibling files present so that Shellcheck errors when - # trying to source a sibling file, which then allows us to extract that path. digest: Digest fp: str @@ -176,32 +174,30 @@ async def infer_shell_dependencies( Get(ExplicitlyProvidedDependencies, DependenciesRequest(wrapped_tgt.target[Dependencies])), Get(HydratedSources, HydrateSourcesRequest(request.sources_field)), ) - per_file_digests = await MultiGet( - Get(Digest, DigestSubset(hydrated_sources.snapshot.digest, PathGlobs([f]))) - for f in hydrated_sources.snapshot.files - ) - all_detected_imports = await MultiGet( - Get(ParsedShellImports, ParseShellImportsRequest(digest, f)) - for digest, f in zip(per_file_digests, hydrated_sources.snapshot.files) - ) + assert len(hydrated_sources.snapshot.files) == 1 + detected_imports = await Get( + ParsedShellImports, + ParseShellImportsRequest( + hydrated_sources.snapshot.digest, hydrated_sources.snapshot.files[0] + ), + ) result: OrderedSet[Address] = OrderedSet() - for detected_imports in all_detected_imports: - for import_path in detected_imports: - unambiguous = shell_mapping.mapping.get(import_path) - ambiguous = shell_mapping.ambiguous_modules.get(import_path) - if unambiguous: - result.add(unambiguous) - elif ambiguous: - explicitly_provided_deps.maybe_warn_of_ambiguous_dependency_inference( - ambiguous, - address, - import_reference="file", - context=f"The target {address} sources `{import_path}`", - ) - maybe_disambiguated = explicitly_provided_deps.disambiguated(ambiguous) - if maybe_disambiguated: - result.add(maybe_disambiguated) + for import_path in detected_imports: + unambiguous = shell_mapping.mapping.get(import_path) + ambiguous = shell_mapping.ambiguous_modules.get(import_path) + if unambiguous: + result.add(unambiguous) + elif ambiguous: + explicitly_provided_deps.maybe_warn_of_ambiguous_dependency_inference( + ambiguous, + address, + import_reference="file", + context=f"The target {address} sources `{import_path}`", + ) + maybe_disambiguated = explicitly_provided_deps.disambiguated(ambiguous) + if maybe_disambiguated: + result.add(maybe_disambiguated) return InferredDependencies(sorted(result)) diff --git a/src/python/pants/engine/internals/graph_test.py b/src/python/pants/engine/internals/graph_test.py index d27fbad6f15..d7ce75758eb 100644 --- a/src/python/pants/engine/internals/graph_test.py +++ b/src/python/pants/engine/internals/graph_test.py @@ -1311,17 +1311,6 @@ def hydrate(sources_cls: Type[MultipleSourcesField], sources: Iterable[str]) -> "f3.txt", ) - # `SingleSourceField` must have one file. - with engine_error(contains="must have 1 file"): - sources_rule_runner.request( - HydratedSources, - [ - HydrateSourcesRequest( - SingleSourceField("*.txt", Address("", target_name="example")) - ), - ], - ) - # ----------------------------------------------------------------------------------------------- # Test codegen. Also see `engine/target_test.py`. diff --git a/src/python/pants/engine/target.py b/src/python/pants/engine/target.py index cf17da1e06d..b9926ecaa83 100644 --- a/src/python/pants/engine/target.py +++ b/src/python/pants/engine/target.py @@ -1697,6 +1697,25 @@ class SingleSourceField(SourcesField, StringField): required = True expected_num_files: ClassVar[int | range] = 1 # Can set to `range(0, 2)` for 0-1 files. + @classmethod + def compute_value(cls, raw_value: Optional[str], address: Address) -> Optional[str]: + value_or_default = super().compute_value(raw_value, address) + if value_or_default is None: + return None + if "*" in value_or_default: + raise InvalidFieldException( + f"The {repr(cls.alias)} field in target {address} should not include `*` globs, " + f"but was set to {value_or_default}. Instead, use a literal file path (relative " + "to the BUILD file)." + ) + if value_or_default.startswith("!"): + raise InvalidFieldException( + f"The {repr(cls.alias)} field in target {address} should not start with `!`, which " + f"is usually used in the `sources` field to exclude certain files. Instead, use a " + "literal file path (relative to the BUILD file)." + ) + return value_or_default + @property def globs(self) -> tuple[str, ...]: # Subclasses might override `required = False`, so `self.value` could be `None`. diff --git a/src/python/pants/engine/target_test.py b/src/python/pants/engine/target_test.py index aab36b7a054..36db3a60226 100644 --- a/src/python/pants/engine/target_test.py +++ b/src/python/pants/engine/target_test.py @@ -1008,6 +1008,16 @@ class GenSources(GenerateSourcesRequest): assert set(result) == {tgt2} +def test_single_source_field_bans_globs() -> None: + class TestSingleSourceField(SingleSourceField): + pass + + with pytest.raises(InvalidFieldException): + TestSingleSourceField("*.ext", Address("project")) + with pytest.raises(InvalidFieldException): + TestSingleSourceField("!f.ext", Address("project")) + + # ----------------------------------------------------------------------------------------------- # Test `ExplicitlyProvidedDependencies` helper functions # -----------------------------------------------------------------------------------------------