diff --git a/pyproject.toml b/pyproject.toml
index 32261b79ce7..5074859c561 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -54,6 +54,7 @@ show_traceback = true
 
 [[tool.mypy.overrides]]
 module = [
+    "botocore",
     "bs4",
     "chevron",
     "colors",
diff --git a/src/python/pants/backend/url_handlers/BUILD b/src/python/pants/backend/url_handlers/BUILD
new file mode 100644
index 00000000000..95c6150585e
--- /dev/null
+++ b/src/python/pants/backend/url_handlers/BUILD
@@ -0,0 +1,4 @@
+# Copyright 2022 Pants project contributors (see CONTRIBUTORS.md).
+# Licensed under the Apache License, Version 2.0 (see LICENSE).
+
+python_sources()
diff --git a/src/python/pants/backend/url_handlers/__init__.py b/src/python/pants/backend/url_handlers/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/src/python/pants/backend/url_handlers/s3/BUILD b/src/python/pants/backend/url_handlers/s3/BUILD
new file mode 100644
index 00000000000..4f317688fd6
--- /dev/null
+++ b/src/python/pants/backend/url_handlers/s3/BUILD
@@ -0,0 +1,5 @@
+# Copyright 2022 Pants project contributors (see CONTRIBUTORS.md).
+# Licensed under the Apache License, Version 2.0 (see LICENSE).
+
+python_sources()
+python_tests(name="tests")
diff --git a/src/python/pants/backend/url_handlers/s3/__init__.py b/src/python/pants/backend/url_handlers/s3/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
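The integration test below never imports the real `botocore`; instead it injects a stub module into `sys.modules` via pytest's `monkeypatch`. A minimal, self-contained sketch of that pattern (the `fakelib` name is hypothetical, not part of the change):

```python
import sys
from types import SimpleNamespace

def test_stubbed_module(monkeypatch):
    # Build a stand-in module object; attribute access mimics the real API surface.
    fake = SimpleNamespace(get_token=lambda: "TOKEN")
    # Any `import fakelib` -- even one inside the code under test -- now yields the
    # stub, and pytest's monkeypatch removes the sys.modules entry at teardown.
    monkeypatch.setitem(sys.modules, "fakelib", fake)

    import fakelib

    assert fakelib.get_token() == "TOKEN"
```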
diff --git a/src/python/pants/backend/url_handlers/s3/integration_test.py b/src/python/pants/backend/url_handlers/s3/integration_test.py
new file mode 100644
index 00000000000..b1975f9f83d
--- /dev/null
+++ b/src/python/pants/backend/url_handlers/s3/integration_test.py
@@ -0,0 +1,190 @@
+# Copyright 2022 Pants project contributors (see CONTRIBUTORS.md).
+# Licensed under the Apache License, Version 2.0 (see LICENSE).
+
+import sys
+from http.server import BaseHTTPRequestHandler
+from types import SimpleNamespace
+
+import pytest
+
+from pants.backend.url_handlers.s3.register import (
+    DownloadS3AuthorityPathStyleURL,
+    DownloadS3AuthorityVirtualHostedStyleURL,
+    DownloadS3SchemeURL,
+)
+from pants.backend.url_handlers.s3.register import rules as s3_rules
+from pants.engine.fs import Digest, FileDigest, NativeDownloadFile, Snapshot
+from pants.engine.rules import QueryRule
+from pants.testutil.rule_runner import RuleRunner
+from pants.util.contextutil import http_server
+
+DOWNLOADS_FILE_DIGEST = FileDigest(
+    "8fcbc50cda241aee7238c71e87c27804e7abc60675974eaf6567aa16366bc105", 14
+)
+DOWNLOADS_EXPECTED_DIRECTORY_DIGEST = Digest(
+    "4c9cf91fcd7ba1abbf7f9a0a1c8175556a82bee6a398e34db3284525ac24a3ad", 84
+)
+
+
+@pytest.fixture
+def rule_runner() -> RuleRunner:
+    return RuleRunner(
+        rules=[
+            *s3_rules(),
+            QueryRule(Snapshot, [DownloadS3SchemeURL]),
+            QueryRule(Snapshot, [DownloadS3AuthorityVirtualHostedStyleURL]),
+            QueryRule(Snapshot, [DownloadS3AuthorityPathStyleURL]),
+        ],
+        isolated_local_store=True,
+    )
+
+
+@pytest.fixture
+def monkeypatch_botocore(monkeypatch):
+    def do_patching(expected_url):
+        botocore = SimpleNamespace()
+        botocore.exceptions = SimpleNamespace(NoCredentialsError=Exception)
+        fake_session = object()
+        fake_creds = SimpleNamespace(access_key="ACCESS", secret_key="SECRET")
+        botocore.session = SimpleNamespace(get_session=lambda: fake_session)
+
+        def fake_resolver_creator(session):
+            assert session is fake_session
+            return SimpleNamespace(load_credentials=lambda: fake_creds)
+
+        def fake_creds_ctor(access_key, secret_key):
+            assert access_key == fake_creds.access_key
+            assert secret_key == fake_creds.secret_key
+            return fake_creds
+
+        botocore.credentials = SimpleNamespace(
+            create_credential_resolver=fake_resolver_creator, Credentials=fake_creds_ctor
+        )
+
+        def fake_auth_ctor(creds):
+            assert creds is fake_creds
+
+            def add_auth(request):
+                assert request.url == expected_url
+                request.headers["AUTH"] = "TOKEN"
+
+            return SimpleNamespace(add_auth=add_auth)
+
+        botocore.auth = SimpleNamespace(HmacV1Auth=fake_auth_ctor)
+
+        monkeypatch.setitem(sys.modules, "botocore", botocore)
+
+    return do_patching
+
+
+@pytest.fixture
+def replace_url(monkeypatch):
+    def with_port(expected_url, port):
+        old_native_download_file_init = NativeDownloadFile.__init__
+
+        def new_init(self, **kwargs):
+            assert kwargs["url"] == expected_url
+            kwargs["url"] = f"http://localhost:{port}/file.txt"
+            return old_native_download_file_init(self, **kwargs)
+
+        monkeypatch.setattr(NativeDownloadFile, "__init__", new_init)
+
+    return with_port
+
+
+@pytest.mark.parametrize(
+    "request_url, expected_auth_url, expected_native_url, req_type",
+    [
+        # NB: Signing always happens against the path-style URL, so the expected
+        # auth URL for the s3:// scheme is path-style as well.
+        (
+            "s3://bucket/keypart1/keypart2/file.txt",
+            "https://s3.amazonaws.com/bucket/keypart1/keypart2/file.txt",
+            "https://bucket.s3.amazonaws.com/keypart1/keypart2/file.txt",
+            DownloadS3SchemeURL,
+        ),
+        # Path-style
+        (
+            "https://s3.amazonaws.com/bucket/keypart1/keypart2/file.txt",
+            "https://s3.amazonaws.com/bucket/keypart1/keypart2/file.txt",
+            "https://bucket.s3.amazonaws.com/keypart1/keypart2/file.txt",
+            DownloadS3AuthorityPathStyleURL,
+        ),
+        (
+            "https://s3.amazonaws.com/bucket/keypart1/keypart2/file.txt?versionId=ABC123",
+            "https://s3.amazonaws.com/bucket/keypart1/keypart2/file.txt?versionId=ABC123",
+            "https://bucket.s3.amazonaws.com/keypart1/keypart2/file.txt?versionId=ABC123",
+            DownloadS3AuthorityPathStyleURL,
+        ),
"https://s3.us-west-2.amazonaws.com/bucket/keypart1/keypart2/file.txt", + "https://s3.us-west-2.amazonaws.com/bucket/keypart1/keypart2/file.txt", + "https://bucket.s3.us-west-2.amazonaws.com/keypart1/keypart2/file.txt", + DownloadS3AuthorityPathStyleURL, + ), + ( + "https://s3.us-west-2.amazonaws.com/bucket/keypart1/keypart2/file.txt?versionId=ABC123", + "https://s3.us-west-2.amazonaws.com/bucket/keypart1/keypart2/file.txt?versionId=ABC123", + "https://bucket.s3.us-west-2.amazonaws.com/keypart1/keypart2/file.txt?versionId=ABC123", + DownloadS3AuthorityPathStyleURL, + ), + # Virtual-hosted-style + ( + "https://bucket.s3.amazonaws.com/keypart1/keypart2/file.txt", + "https://s3.amazonaws.com/bucket/keypart1/keypart2/file.txt", + "https://bucket.s3.amazonaws.com/keypart1/keypart2/file.txt", + DownloadS3AuthorityVirtualHostedStyleURL, + ), + ( + "https://bucket.s3.amazonaws.com/keypart1/keypart2/file.txt?versionId=ABC123", + "https://s3.amazonaws.com/bucket/keypart1/keypart2/file.txt?versionId=ABC123", + "https://bucket.s3.amazonaws.com/keypart1/keypart2/file.txt?versionId=ABC123", + DownloadS3AuthorityVirtualHostedStyleURL, + ), + ( + "https://bucket.s3.amazonaws.com/keypart1/keypart2/file.txt?versionId=ABC123", + "https://s3.amazonaws.com/bucket/keypart1/keypart2/file.txt?versionId=ABC123", + "https://bucket.s3.amazonaws.com/keypart1/keypart2/file.txt?versionId=ABC123", + DownloadS3AuthorityVirtualHostedStyleURL, + ), + ( + "https://bucket.s3.us-west-2.amazonaws.com/keypart1/keypart2/file.txt", + "https://s3.us-west-2.amazonaws.com/bucket/keypart1/keypart2/file.txt", + "https://bucket.s3.us-west-2.amazonaws.com/keypart1/keypart2/file.txt", + DownloadS3AuthorityVirtualHostedStyleURL, + ), + ], +) +def test_download_s3( + rule_runner: RuleRunner, + monkeypatch_botocore, + request_url: str, + expected_auth_url: str, + expected_native_url: str, + req_type: type, + replace_url, +) -> None: + class S3HTTPHandler(BaseHTTPRequestHandler): + response_text = b"Hello, client!" + + def do_HEAD(self): + self.send_headers() + + def do_GET(self): + self.send_headers() + self.wfile.write(self.response_text) + + def send_headers(self): + assert self.headers["AUTH"] == "TOKEN" + self.send_response(200) + self.send_header("Content-Type", "binary/octet-stream") + self.send_header("Content-Length", f"{len(self.response_text)}") + self.end_headers() + + monkeypatch_botocore(expected_auth_url) + with http_server(S3HTTPHandler) as port: + replace_url(expected_native_url, port) + snapshot = rule_runner.request( + Snapshot, + [req_type(request_url, DOWNLOADS_FILE_DIGEST)], + ) + assert snapshot.files == ("file.txt",) + assert snapshot.digest == DOWNLOADS_EXPECTED_DIRECTORY_DIGEST diff --git a/src/python/pants/backend/url_handlers/s3/register.py b/src/python/pants/backend/url_handlers/s3/register.py new file mode 100644 index 00000000000..862418e1bc5 --- /dev/null +++ b/src/python/pants/backend/url_handlers/s3/register.py @@ -0,0 +1,175 @@ +# Copyright 2022 Pants project contributors (see CONTRIBUTORS.md). +# Licensed under the Apache License, Version 2.0 (see LICENSE). 
diff --git a/src/python/pants/backend/url_handlers/s3/register.py b/src/python/pants/backend/url_handlers/s3/register.py
new file mode 100644
index 00000000000..862418e1bc5
--- /dev/null
+++ b/src/python/pants/backend/url_handlers/s3/register.py
@@ -0,0 +1,175 @@
+# Copyright 2022 Pants project contributors (see CONTRIBUTORS.md).
+# Licensed under the Apache License, Version 2.0 (see LICENSE).
+
+import logging
+from dataclasses import dataclass
+from types import SimpleNamespace
+from typing import Any
+from urllib.parse import urlsplit
+
+from pants.engine.download_file import URLDownloadHandler
+from pants.engine.fs import Digest, NativeDownloadFile
+from pants.engine.internals.native_engine import FileDigest
+from pants.engine.internals.selectors import Get
+from pants.engine.rules import collect_rules, rule
+from pants.engine.unions import UnionRule
+from pants.util.strutil import softwrap
+
+CONTENT_TYPE = "binary/octet-stream"
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass(frozen=True)
+class AWSCredentials:
+    creds: Any
+
+
+@rule
+async def access_aws_credentials() -> AWSCredentials:
+    try:
+        from botocore import credentials, session
+    except ImportError:
+        logger.warning(
+            softwrap(
+                """
+                In order to resolve s3:// URLs, Pants must load AWS credentials. To do so,
+                `botocore` must be importable in Pants' environment.
+
+                To do so, add an entry to `[GLOBAL].plugins` for a pip-resolvable package that
+                provides `botocore` (e.g. `botocore == 1.29.39`). Note that the `botocore`
+                package from PyPI is, at the time of writing, >70MB, so an alternate package
+                providing the `botocore` modules may be advisable.
+                """
+            )
+        )
+        raise
+
+    session = session.get_session()
+    creds = credentials.create_credential_resolver(session).load_credentials()
+
+    return AWSCredentials(creds)
+
+
+@dataclass(frozen=True)
+class S3DownloadFile:
+    region: str
+    bucket: str
+    key: str
+    query: str
+    expected_digest: FileDigest
+
+
+@rule
+async def download_from_s3(request: S3DownloadFile, aws_credentials: AWSCredentials) -> Digest:
+    from botocore import auth, exceptions  # pants: no-infer-dep
+
+    # NB: The URL for auth is expected to be in path-style
+    path_style_url = "https://s3"
+    if request.region:
+        path_style_url += f".{request.region}"
+    path_style_url += f".amazonaws.com/{request.bucket}/{request.key}"
+    if request.query:
+        path_style_url += f"?{request.query}"
+
+    http_request = SimpleNamespace(
+        url=path_style_url,
+        headers={},
+        method="GET",
+        auth_path=None,
+    )
+    # NB: The added Auth header doesn't need to be valid when accessing a public bucket. When
+    # hand-testing, you MUST test against a private bucket to ensure it works for private buckets too.
+    signer = auth.HmacV1Auth(aws_credentials.creds)
+    try:
+        signer.add_auth(http_request)
+    except exceptions.NoCredentialsError:
+        pass  # The user can still access public S3 buckets without credentials
+
+    virtual_hosted_url = f"https://{request.bucket}.s3"
+    if request.region:
+        virtual_hosted_url += f".{request.region}"
+    virtual_hosted_url += f".amazonaws.com/{request.key}"
+    if request.query:
+        virtual_hosted_url += f"?{request.query}"
+
+    return await Get(
+        Digest,
+        NativeDownloadFile(
+            url=virtual_hosted_url,
+            expected_digest=request.expected_digest,
+            auth_headers=http_request.headers,
+        ),
+    )
+
+
+class DownloadS3SchemeURL(URLDownloadHandler):
+    match_scheme = "s3"
+
+
+@rule
+async def download_file_from_s3_scheme(request: DownloadS3SchemeURL) -> Digest:
+    split = urlsplit(request.url)
+    return await Get(
+        Digest,
+        S3DownloadFile(
+            region="",
+            bucket=split.netloc,
+            key=split.path[1:],
+            query="",
+            expected_digest=request.expected_digest,
+        ),
+    )
+
+
+class DownloadS3AuthorityVirtualHostedStyleURL(URLDownloadHandler):
+    match_authority = "*.s3*amazonaws.com"
+
+
+@rule
+async def download_file_from_virtual_hosted_s3_authority(
+    request: DownloadS3AuthorityVirtualHostedStyleURL, aws_credentials: AWSCredentials
+) -> Digest:
+    split = urlsplit(request.url)
+    bucket, aws_netloc = split.netloc.split(".", 1)
+    return await Get(
+        Digest,
+        S3DownloadFile(
+            region=aws_netloc.split(".")[1] if aws_netloc.count(".") == 3 else "",
+            bucket=bucket,
+            key=split.path[1:],
+            query=split.query,
+            expected_digest=request.expected_digest,
+        ),
+    )
+
+
+class DownloadS3AuthorityPathStyleURL(URLDownloadHandler):
+    match_authority = "s3.*amazonaws.com"
+
+
+@rule
+async def download_file_from_path_s3_authority(
+    request: DownloadS3AuthorityPathStyleURL, aws_credentials: AWSCredentials
+) -> Digest:
+    split = urlsplit(request.url)
+    _, bucket, key = split.path.split("/", 2)
+    return await Get(
+        Digest,
+        S3DownloadFile(
+            region=split.netloc.split(".")[1] if split.netloc.count(".") == 3 else "",
+            bucket=bucket,
+            key=key,
+            query=split.query,
+            expected_digest=request.expected_digest,
+        ),
+    )
+
+
+def rules():
+    return [
+        UnionRule(URLDownloadHandler, DownloadS3SchemeURL),
+        UnionRule(URLDownloadHandler, DownloadS3AuthorityVirtualHostedStyleURL),
+        UnionRule(URLDownloadHandler, DownloadS3AuthorityPathStyleURL),
+        *collect_rules(),
+    ]
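For reference, the signing dance in `download_from_s3` can be reproduced against botocore directly. A minimal sketch, assuming `botocore` is installed and using throwaway credentials (a real run would use `create_credential_resolver(...).load_credentials()` as above):

```python
from botocore.auth import HmacV1Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials

signer = HmacV1Auth(Credentials(access_key="AKID", secret_key="SECRET"))
request = AWSRequest(method="GET", url="https://s3.amazonaws.com/bucket/key/file.txt")
signer.add_auth(request)  # mutates request.headers in place
print(request.headers["Authorization"])  # e.g. "AWS AKID:..."
```

The rule itself substitutes a `SimpleNamespace` for `AWSRequest`, since `add_auth` only touches the `url`, `method`, `headers`, and `auth_path` attributes, which keeps botocore out of the request's type signature.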
diff --git a/src/python/pants/engine/download_file.py b/src/python/pants/engine/download_file.py
new file mode 100644
index 00000000000..5c707c47baf
--- /dev/null
+++ b/src/python/pants/engine/download_file.py
@@ -0,0 +1,103 @@
+# Copyright 2022 Pants project contributors (see CONTRIBUTORS.md).
+# Licensed under the Apache License, Version 2.0 (see LICENSE).
+
+from dataclasses import dataclass
+from fnmatch import fnmatch
+from typing import ClassVar, Optional
+from urllib.parse import urlparse
+
+from pants.engine.fs import Digest, DownloadFile, NativeDownloadFile
+from pants.engine.internals.native_engine import FileDigest
+from pants.engine.internals.selectors import Get
+from pants.engine.rules import collect_rules, rule
+from pants.engine.unions import UnionMembership, union
+from pants.util.strutil import bullet_list, softwrap
+
+
+@union
+@dataclass(frozen=True)
+class URLDownloadHandler:
+    """Union base for custom URL handlers.
+
+    To register a custom URL handler:
+    - Subclass this class and declare one or both of the ClassVars.
+    - Declare a rule that takes in your class type and returns a `Digest`.
+    - Register your union member in your `rules()`: `UnionRule(URLDownloadHandler, YourClass)`.
+
+    Example:
+
+        class S3DownloadHandler(URLDownloadHandler):
+            match_scheme = "s3"
+
+        @rule
+        async def download_s3_file(request: S3DownloadHandler) -> Digest:
+            # Look up auth tokens, etc...
+            # Ideally, download the file using `NativeDownloadFile()`
+            return digest
+
+        def rules():
+            return [
+                *collect_rules(),
+                UnionRule(URLDownloadHandler, S3DownloadHandler),
+            ]
+    """
+
+    match_scheme: ClassVar[Optional[str]] = None
+    """The scheme to match (e.g. 'ftp' or 's3'), or `None` to match all schemes.
+
+    The scheme is matched using `fnmatch`; see https://docs.python.org/3/library/fnmatch.html
+    for more information.
+    """
+
+    match_authority: ClassVar[Optional[str]] = None
+    """The authority to match (e.g. 'pantsbuild.org' or 's3.amazonaws.com'), or `None` to match
+    all authorities.
+
+    The authority is matched using `fnmatch`; see https://docs.python.org/3/library/fnmatch.html
+    for more information.
+
+    Note that the authority matches userinfo (e.g. 'me@pantsbuild.org' or
+    'me:password@pantsbuild.org') as well as port (e.g. 'pantsbuild.org:80').
+    """
+
+    url: str
+    expected_digest: FileDigest
+
+
+@rule
+async def download_file(
+    request: DownloadFile,
+    union_membership: UnionMembership,
+) -> Digest:
+    parsed_url = urlparse(request.url)
+    handlers = union_membership.get(URLDownloadHandler)
+    matched_handlers = []
+    for handler in handlers:
+        matches_scheme = handler.match_scheme is None or fnmatch(
+            parsed_url.scheme, handler.match_scheme
+        )
+        matches_authority = handler.match_authority is None or fnmatch(
+            parsed_url.netloc, handler.match_authority
+        )
+        if matches_scheme and matches_authority:
+            matched_handlers.append(handler)
+
+    if len(matched_handlers) > 1:
+        raise Exception(
+            softwrap(
+                f"""
+                Too many registered URL handlers match the URL '{request.url}'.
+
+                Matched handlers:
+                {bullet_list(map(str, matched_handlers))}
+                """
+            )
+        )
+    if len(matched_handlers) == 1:
+        handler = matched_handlers[0]
+        return await Get(Digest, URLDownloadHandler, handler(request.url, request.expected_digest))
+
+    return await Get(Digest, NativeDownloadFile(request.url, request.expected_digest))
+
+
+def rules():
+    return collect_rules()
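The matching in `download_file` is plain `fnmatch` against the split URL, which is why the S3 handlers above can use patterns like `*.s3*amazonaws.com`. A quick illustration of the semantics:

```python
from fnmatch import fnmatch
from urllib.parse import urlparse

parsed = urlparse("https://bucket.s3.us-west-2.amazonaws.com/key/file.txt")

assert fnmatch(parsed.scheme, "http*")
assert fnmatch(parsed.netloc, "*.s3*amazonaws.com")     # virtual-hosted style
assert not fnmatch(parsed.netloc, "s3.*amazonaws.com")  # path-style pattern doesn't match
# The netloc includes any port, so patterns may need a trailing wildcard:
assert fnmatch("bucket.s3.amazonaws.com:443", "*.s3*amazonaws.com*")
```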
diff --git a/src/python/pants/engine/download_file_integration_test.py b/src/python/pants/engine/download_file_integration_test.py
new file mode 100644
index 00000000000..1848fcaa77e
--- /dev/null
+++ b/src/python/pants/engine/download_file_integration_test.py
@@ -0,0 +1,174 @@
+# Copyright 2022 Pants project contributors (see CONTRIBUTORS.md).
+# Licensed under the Apache License, Version 2.0 (see LICENSE).
+
+import pytest
+
+from pants.engine.download_file import URLDownloadHandler, download_file
+from pants.engine.fs import Digest, DownloadFile, FileDigest, NativeDownloadFile
+from pants.engine.unions import UnionMembership
+from pants.testutil.rule_runner import MockGet, run_rule_with_mocks
+
+DOWNLOADS_FILE_DIGEST = FileDigest(
+    "8fcbc50cda241aee7238c71e87c27804e7abc60675974eaf6567aa16366bc105", 14
+)
+DOWNLOADS_EXPECTED_DIRECTORY_DIGEST = Digest(
+    "4c9cf91fcd7ba1abbf7f9a0a1c8175556a82bee6a398e34db3284525ac24a3ad", 84
+)
+
+
+def test_no_union_members() -> None:
+    union_membership = UnionMembership({})
+    digest = run_rule_with_mocks(
+        download_file,
+        rule_args=[
+            DownloadFile("http://pantsbuild.com/file.txt", DOWNLOADS_FILE_DIGEST),
+            union_membership,
+        ],
+        mock_gets=[
+            MockGet(
+                output_type=Digest,
+                input_types=(URLDownloadHandler,),
+                mock=lambda _: None,
+            ),
+            MockGet(
+                output_type=Digest,
+                input_types=(NativeDownloadFile,),
+                mock=lambda _: DOWNLOADS_EXPECTED_DIRECTORY_DIGEST,
+            ),
+        ],
+        union_membership=union_membership,
+    )
+    assert digest == DOWNLOADS_EXPECTED_DIRECTORY_DIGEST
+
+
+@pytest.mark.parametrize(
+    "scheme, authority, url",
+    [
+        # Anything (every URL matches)
+        (None, None, "s3://pantsbuild.com/file.txt"),
+        (None, None, "http://pantsbuild.com/file.txt"),
+        (None, None, "http://awesome.pantsbuild.com/file.txt"),
+        # Scheme
+        ("s3", None, "s3://pantsbuild.com/file.txt"),
+        ("http*", None, "http://pantsbuild.com/file.txt"),
+        ("http*", None, "https://pantsbuild.com/file.txt"),
+        # Authority
+        (None, "pantsbuild.com", "s3://pantsbuild.com/file.txt"),
+        (None, "pantsbuild.com", "http://pantsbuild.com/file.txt"),
+        (None, "pantsbuild.com", "https://pantsbuild.com/file.txt"),
+        (None, "*.pantsbuild.com", "https://awesome.pantsbuild.com/file.txt"),
+        (None, "*.pantsbuild.com*", "https://awesome.pantsbuild.com/file.txt"),
+        (None, "*.pantsbuild.com*", "https://awesome.pantsbuild.com:80/file.txt"),
+        # Both
+        ("http*", "*.pantsbuild.com", "http://awesome.pantsbuild.com/file.txt"),
+    ],
+)
+def test_matches(scheme, authority, url) -> None:
+    class UnionMember(URLDownloadHandler):
+        match_scheme = scheme
+        match_authority = authority
+
+        def mock_rule(self) -> Digest:
+            assert isinstance(self, UnionMember)
+            return DOWNLOADS_EXPECTED_DIRECTORY_DIGEST
+
+    union_membership = UnionMembership({URLDownloadHandler: [UnionMember]})
+
+    digest = run_rule_with_mocks(
+        download_file,
+        rule_args=[
+            DownloadFile(url, DOWNLOADS_FILE_DIGEST),
+            union_membership,
+        ],
+        mock_gets=[
+            MockGet(
+                output_type=Digest,
+                input_types=(URLDownloadHandler,),
+                mock=UnionMember.mock_rule,
+            ),
+            MockGet(
+                output_type=Digest,
+                input_types=(NativeDownloadFile,),
+                mock=lambda _: None,
+            ),
+        ],
+        union_membership=union_membership,
+    )
+    assert digest == DOWNLOADS_EXPECTED_DIRECTORY_DIGEST
+
+
+@pytest.mark.parametrize(
+    "scheme, authority, url",
+    [
+        # Scheme
+        ("s3", None, "http://pantsbuild.com/file.txt"),
+        ("s3", None, "as3://pantsbuild.com/file.txt"),
+        ("http", None, "https://pantsbuild.com/file.txt"),
+        # Authority
+        (None, "pantsbuild.com", "http://pantsbuild.com:80/file.txt"),
+        (None, "*.pantsbuild.com", "https://pantsbuild.com/file.txt"),
+        # Both
+        ("http", "*.pantsbuild.com", "https://awesome.pantsbuild.com/file.txt"),
+        ("https", "*.pantsbuild.com", "https://pantsbuild.com/file.txt"),
+    ],
+)
+def test_doesnt_match(scheme, authority, url) -> None:
+    class UnionMember(URLDownloadHandler):
+        match_scheme = scheme
+        match_authority = authority
+    union_membership = UnionMembership({URLDownloadHandler: [UnionMember]})
+
+    digest = run_rule_with_mocks(
+        download_file,
+        rule_args=[
+            DownloadFile(url, DOWNLOADS_FILE_DIGEST),
+            union_membership,
+        ],
+        mock_gets=[
+            MockGet(
+                output_type=Digest,
+                input_types=(URLDownloadHandler,),
+                mock=lambda _: None,
+            ),
+            MockGet(
+                output_type=Digest,
+                input_types=(NativeDownloadFile,),
+                mock=lambda _: DOWNLOADS_EXPECTED_DIRECTORY_DIGEST,
+            ),
+        ],
+        union_membership=union_membership,
+    )
+    assert digest == DOWNLOADS_EXPECTED_DIRECTORY_DIGEST
+
+
+def test_too_many_matches() -> None:
+    class AuthorityMatcher(URLDownloadHandler):
+        match_authority = "pantsbuild.com"
+
+    class SchemeMatcher(URLDownloadHandler):
+        match_scheme = "http"
+
+    union_membership = UnionMembership({URLDownloadHandler: [AuthorityMatcher, SchemeMatcher]})
+
+    with pytest.raises(Exception, match=r"Too many registered URL handlers"):
+        run_rule_with_mocks(
+            download_file,
+            rule_args=[
+                DownloadFile("http://pantsbuild.com/file.txt", DOWNLOADS_FILE_DIGEST),
+                union_membership,
+            ],
+            mock_gets=[
+                MockGet(
+                    output_type=Digest,
+                    input_types=(URLDownloadHandler,),
+                    mock=lambda _: None,
+                ),
+                MockGet(
+                    output_type=Digest,
+                    input_types=(NativeDownloadFile,),
+                    mock=lambda _: DOWNLOADS_EXPECTED_DIRECTORY_DIGEST,
+                ),
+            ],
+            union_membership=union_membership,
+        )
diff --git a/src/python/pants/engine/fs.py b/src/python/pants/engine/fs.py
index 9c1422020b6..eaad319325c 100644
--- a/src/python/pants/engine/fs.py
+++ b/src/python/pants/engine/fs.py
@@ -5,7 +5,7 @@
 
 from dataclasses import dataclass
 from enum import Enum
-from typing import TYPE_CHECKING, Iterable, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Tuple, Union
 
 # Note: several of these types are re-exported as the public API of `engine/fs.py`.
 from pants.base.glob_match_error_behavior import GlobMatchErrorBehavior as GlobMatchErrorBehavior
@@ -23,6 +23,7 @@
 from pants.engine.internals.native_engine import RemovePrefix as RemovePrefix
 from pants.engine.internals.native_engine import Snapshot as Snapshot
 from pants.engine.rules import QueryRule
+from pants.util.frozendict import FrozenDict
 from pants.util.meta import frozen_after_init
 
 if TYPE_CHECKING:
@@ -248,6 +249,34 @@ class DownloadFile:
     expected_digest: FileDigest
 
 
+@frozen_after_init
+@dataclass(unsafe_hash=True)
+class NativeDownloadFile:
+    """Retrieve the contents of a file via an HTTP GET request, or directly for local file:// URLs.
+
+    This request is handled directly by the native engine without any additional coercion by
+    plugins, and therefore should only be used in cases where the URL is known to be publicly
+    accessible. Otherwise, callers should use `DownloadFile`.
+
+    The auth_headers are part of this node's cache key for memoization (changing a header
+    invalidates prior results) but are not part of the underlying cache key for the local/remote
+    cache (changing a header won't re-download a file if the file was previously downloaded).
+    """
+
+    url: str
+    expected_digest: FileDigest
+    # NB: This mapping can hold arbitrary headers, but it should be limited to those required
+    # for authorization.
+    auth_headers: FrozenDict[str, str]
+
+    def __init__(
+        self, url: str, expected_digest: FileDigest, auth_headers: Mapping[str, str] | None = None
+    ) -> None:
+        self.url = url
+        self.expected_digest = expected_digest
+        self.auth_headers = FrozenDict(auth_headers or {})
+
+
 @dataclass(frozen=True)
 class Workspace(SideEffecting):
     """A handle for operations that mutate the local filesystem."""
@@ -300,7 +329,7 @@ def rules():
         QueryRule(Digest, (PathGlobs,)),
         QueryRule(Digest, (AddPrefix,)),
         QueryRule(Digest, (RemovePrefix,)),
-        QueryRule(Digest, (DownloadFile,)),
+        QueryRule(Digest, (NativeDownloadFile,)),
         QueryRule(Digest, (MergeDigests,)),
         QueryRule(Digest, (DigestSubset,)),
         QueryRule(DigestContents, (Digest,)),
diff --git a/src/python/pants/engine/internals/scheduler.py b/src/python/pants/engine/internals/scheduler.py
index dcaec6e6279..d77552a1316 100644
--- a/src/python/pants/engine/internals/scheduler.py
+++ b/src/python/pants/engine/internals/scheduler.py
@@ -22,10 +22,10 @@
     DigestEntries,
     DigestSubset,
     Directory,
-    DownloadFile,
     FileContent,
     FileDigest,
     FileEntry,
+    NativeDownloadFile,
     PathGlobs,
     PathGlobsAndRoot,
     Paths,
@@ -157,7 +157,7 @@ def __init__(
             path_globs=PathGlobs,
             create_digest=CreateDigest,
             digest_subset=DigestSubset,
-            download_file=DownloadFile,
+            native_download_file=NativeDownloadFile,
             platform=Platform,
             process=Process,
             process_result=FallibleProcessResult,
diff --git a/src/python/pants/init/engine_initializer.py b/src/python/pants/init/engine_initializer.py
index 968838a6bba..a79b1d42c39 100644
--- a/src/python/pants/init/engine_initializer.py
+++ b/src/python/pants/init/engine_initializer.py
@@ -16,7 +16,7 @@
 from pants.build_graph.build_configuration import BuildConfiguration
 from pants.core.util_rules import environments, system_binaries
 from pants.core.util_rules.environments import determine_bootstrap_environment
-from pants.engine import desktop, fs, process
+from pants.engine import desktop, download_file, fs, process
 from pants.engine.console import Console
 from pants.engine.environment import EnvironmentName
 from pants.engine.fs import PathGlobs, Snapshot, Workspace
@@ -272,6 +272,7 @@ def build_root_singleton() -> BuildRoot:
         *fs.rules(),
         *dep_rules.rules(),
         *desktop.rules(),
+        *download_file.rules(),
         *git_rules(),
         *graph.rules(),
         *specs_rules.rules(),
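With the Python-side plumbing above in place, a plugin rule can pass pre-computed auth headers straight to the intrinsic. A hypothetical rule body (the URL, digest, and token are illustrative, not part of the change):

```python
from pants.engine.fs import Digest, FileDigest, NativeDownloadFile
from pants.engine.internals.selectors import Get
from pants.engine.rules import rule


@rule
async def fetch_internal_artifact() -> Digest:
    # The header mapping is coerced to a FrozenDict and participates in memoization,
    # but not in the local/remote cache key (see the NativeDownloadFile docstring).
    return await Get(
        Digest,
        NativeDownloadFile(
            url="https://artifacts.example.com/file.txt",
            expected_digest=FileDigest(
                "8fcbc50cda241aee7238c71e87c27804e7abc60675974eaf6567aa16366bc105", 14
            ),
            auth_headers={"Authorization": "Bearer TOKEN"},
        ),
    )
```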
diff --git a/src/rust/engine/src/downloads.rs b/src/rust/engine/src/downloads.rs
index a99b4702894..01df8bc24a8 100644
--- a/src/rust/engine/src/downloads.rs
+++ b/src/rust/engine/src/downloads.rs
@@ -1,6 +1,7 @@
 // Copyright 2021 Pants project contributors (see CONTRIBUTORS.md).
 // Licensed under the Apache License, Version 2.0 (see LICENSE).
 
+use std::collections::BTreeMap;
 use std::io::{self, Write};
 use std::pin::Pin;
 use std::sync::Arc;
@@ -10,6 +11,7 @@
 use bytes::{BufMut, Bytes};
 use futures::stream::StreamExt;
 use hashing::Digest;
 use humansize::{file_size_opts, FileSize};
+use reqwest::header::{HeaderMap, HeaderName};
 use reqwest::Error;
 use tokio_retry::strategy::{jitter, ExponentialBackoff};
 use tokio_retry::RetryIf;
@@ -44,11 +46,21 @@ impl NetDownload {
   async fn start(
     core: &Arc<Core>,
     url: Url,
+    auth_headers: BTreeMap<String, String>,
     file_name: String,
   ) -> Result<NetDownload, StreamingError> {
+    let mut headers = HeaderMap::new();
+    for (k, v) in &auth_headers {
+      headers.insert(
+        HeaderName::from_bytes(k.as_bytes()).unwrap(),
+        v.parse().unwrap(),
+      );
+    }
+
     let response = core
       .http_client
       .get(url.clone())
+      .headers(headers)
       .send()
       .await
       .map_err(|err| StreamingError::Retryable(format!("Error downloading file: {err}")))
@@ -127,6 +139,7 @@ impl StreamingDownload for FileDownload {
 async fn attempt_download(
   core: &Arc<Core>,
   url: &Url,
+  auth_headers: &BTreeMap<String, String>,
   file_name: String,
   expected_digest: Digest,
 ) -> Result<(Digest, Bytes), StreamingError> {
@@ -144,7 +157,7 @@ async fn attempt_download(
       }
       Box::new(FileDownload::start(url.path(), file_name).await?)
     } else {
-      Box::new(NetDownload::start(core, url.clone(), file_name).await?)
+      Box::new(NetDownload::start(core, url.clone(), auth_headers.clone(), file_name).await?)
     }
   };
@@ -195,6 +208,7 @@
 pub async fn download(
   core: Arc<Core>,
   url: Url,
+  auth_headers: BTreeMap<String, String>,
   file_name: String,
   expected_digest: hashing::Digest,
 ) -> Result<(), String> {
@@ -215,7 +229,15 @@
   let retry_strategy = ExponentialBackoff::from_millis(10).map(jitter).take(4);
   RetryIf::spawn(
     retry_strategy,
-    || attempt_download(&core2, &url, file_name.clone(), expected_digest),
+    || {
+      attempt_download(
+        &core2,
+        &url,
+        &auth_headers,
+        file_name.clone(),
+        expected_digest,
+      )
+    },
     |err: &StreamingError| matches!(err, StreamingError::Retryable(_)),
   )
   .await
diff --git a/src/rust/engine/src/externs/interface.rs b/src/rust/engine/src/externs/interface.rs
index b89a7fd7c83..a7c46f937cc 100644
--- a/src/rust/engine/src/externs/interface.rs
+++ b/src/rust/engine/src/externs/interface.rs
@@ -183,7 +183,7 @@ impl PyTypes {
     path_globs: &PyType,
     create_digest: &PyType,
     digest_subset: &PyType,
-    download_file: &PyType,
+    native_download_file: &PyType,
     platform: &PyType,
     process: &PyType,
     process_result: &PyType,
@@ -215,7 +215,7 @@ impl PyTypes {
           remove_prefix: TypeId::new(py.get_type::<externs::fs::PyRemovePrefix>()),
           create_digest: TypeId::new(create_digest),
           digest_subset: TypeId::new(digest_subset),
-          download_file: TypeId::new(download_file),
+          native_download_file: TypeId::new(native_download_file),
           platform: TypeId::new(platform),
           process: TypeId::new(process),
           process_result: TypeId::new(process_result),
diff --git a/src/rust/engine/src/intrinsics.rs b/src/rust/engine/src/intrinsics.rs
index 0b28051e713..5d2bd4af328 100644
--- a/src/rust/engine/src/intrinsics.rs
+++ b/src/rust/engine/src/intrinsics.rs
@@ -48,6 +48,7 @@
 pub struct Intrinsics {
   intrinsics: IndexMap<Intrinsic, IntrinsicFn>,
 }
 
+// NB: Keep in sync with `rules()` in `src/python/pants/engine/fs.py`.
 impl Intrinsics {
   pub fn new(types: &Types) -> Intrinsics {
     let mut intrinsics: IndexMap<Intrinsic, IntrinsicFn> = IndexMap::new();
@@ -64,7 +65,7 @@ impl Intrinsics {
       Box::new(path_globs_to_paths),
     );
     intrinsics.insert(
-      Intrinsic::new(types.directory_digest, types.download_file),
+      Intrinsic::new(types.directory_digest, types.native_download_file),
       Box::new(download_file_to_digest),
     );
     intrinsics.insert(
diff --git a/src/rust/engine/src/nodes.rs b/src/rust/engine/src/nodes.rs
index 280c1ca22d0..34257b0a741 100644
--- a/src/rust/engine/src/nodes.rs
+++ b/src/rust/engine/src/nodes.rs
@@ -969,6 +969,7 @@ impl DownloadedFile {
     &self,
     core: Arc<Core>,
     url: Url,
+    auth_headers: BTreeMap<String, String>,
     digest: hashing::Digest,
   ) -> Result<store::Snapshot, String> {
     let file_name = url
@@ -986,6 +987,7 @@ impl DownloadedFile {
     // See if we have observed this URL and Digest before: if so, see whether we already have the
     // Digest fetched. The extra layer of indirection through the PersistentCache is to sanity
     // check that a Digest has ever been observed at the given URL.
+    // NB: The auth_headers are not part of the key.
     let url_key = Self::url_key(&url, digest);
     let have_observed_url = core.local_cache.load(&url_key).await?.is_some();
 
@@ -999,7 +1001,7 @@ impl DownloadedFile {
       .is_ok());
 
     if !usable_in_store {
-      downloads::download(core.clone(), url, file_name, digest).await?;
+      downloads::download(core.clone(), url, auth_headers, file_name, digest).await?;
       // The value was successfully fetched and matched the digest: record in the ObservedUrls
       // cache.
       core.local_cache.store(&url_key, Bytes::from("")).await?;
@@ -1008,19 +1010,21 @@ impl DownloadedFile {
   }
 
   async fn run_node(self, context: Context) -> NodeResult<store::Snapshot> {
-    let (url_str, expected_digest) = Python::with_gil(|py| {
+    let (url_str, expected_digest, auth_headers) = Python::with_gil(|py| {
       let py_download_file_val = self.0.to_value();
       let py_download_file = (*py_download_file_val).as_ref(py);
       let url_str: String = externs::getattr(py_download_file, "url").unwrap();
+      let auth_headers = externs::getattr_from_str_frozendict(py_download_file, "auth_headers");
       let py_file_digest: PyFileDigest =
         externs::getattr(py_download_file, "expected_digest").unwrap();
-      let res: NodeResult<(String, Digest)> = Ok((url_str, py_file_digest.0));
+      let res: NodeResult<(String, Digest, BTreeMap<String, String>)> =
+        Ok((url_str, py_file_digest.0, auth_headers));
       res
     })?;
     let url = Url::parse(&url_str)
      .map_err(|err| throw(format!("Error parsing URL {}: {}", url_str, err)))?;
     self
-      .load_or_download(context.core, url, expected_digest)
+      .load_or_download(context.core, url, auth_headers, expected_digest)
       .await
       .map_err(throw)
   }
diff --git a/src/rust/engine/src/types.rs b/src/rust/engine/src/types.rs
index 1d1eecc4641..75998318911 100644
--- a/src/rust/engine/src/types.rs
+++ b/src/rust/engine/src/types.rs
@@ -20,7 +20,7 @@ pub struct Types {
   pub remove_prefix: TypeId,
   pub create_digest: TypeId,
   pub digest_subset: TypeId,
-  pub download_file: TypeId,
+  pub native_download_file: TypeId,
   pub platform: TypeId,
   pub process: TypeId,
   pub process_config_from_environment: TypeId,