Allow plugins to add custom schema/authority URL handler rules #17898

Merged
merged 7 commits into from
Jan 4, 2023

Changes from 6 commits
1 change: 1 addition & 0 deletions pyproject.toml
@@ -54,6 +54,7 @@ show_traceback = true

[[tool.mypy.overrides]]
module = [
"botocore",
"bs4",
"chevron",
"colors",
4 changes: 4 additions & 0 deletions src/python/pants/backend/url_handlers/BUILD
@@ -0,0 +1,4 @@
# Copyright 2022 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).

python_sources()
Empty file.
5 changes: 5 additions & 0 deletions src/python/pants/backend/url_handlers/s3/BUILD
@@ -0,0 +1,5 @@
# Copyright 2022 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).

python_sources()
python_tests(name="tests")
Empty file.
107 changes: 107 additions & 0 deletions src/python/pants/backend/url_handlers/s3/integration_test.py
@@ -0,0 +1,107 @@
# Copyright 2022 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).

import sys
from http.server import BaseHTTPRequestHandler
from types import SimpleNamespace

import pytest

from pants.backend.url_handlers.s3.register import DownloadS3SchemeURL
from pants.backend.url_handlers.s3.register import rules as s3_rules
from pants.engine.fs import Digest, FileDigest, NativeDownloadFile, Snapshot
from pants.engine.rules import QueryRule
from pants.testutil.rule_runner import RuleRunner
from pants.util.contextutil import http_server

DOWNLOADS_FILE_DIGEST = FileDigest(
"8fcbc50cda241aee7238c71e87c27804e7abc60675974eaf6567aa16366bc105", 14
)
DOWNLOADS_EXPECTED_DIRECTORY_DIGEST = Digest(
"4c9cf91fcd7ba1abbf7f9a0a1c8175556a82bee6a398e34db3284525ac24a3ad", 84
)


@pytest.fixture
def rule_runner() -> RuleRunner:
return RuleRunner(
rules=[
*s3_rules(),
QueryRule(Snapshot, [DownloadS3SchemeURL]),
],
isolated_local_store=True,
)


@pytest.fixture
def monkeypatch_botocore(monkeypatch):
botocore = SimpleNamespace()
fake_session = object()
fake_creds = SimpleNamespace(access_key="ACCESS", secret_key="SECRET")
botocore.session = SimpleNamespace(get_session=lambda: fake_session)

def fake_resolver_creator(session):
assert session is fake_session
return SimpleNamespace(load_credentials=lambda: fake_creds)

def fake_creds_ctor(access_key, secret_key):
assert access_key == fake_creds.access_key
assert secret_key == fake_creds.secret_key
return fake_creds

botocore.credentials = SimpleNamespace(
create_credential_resolver=fake_resolver_creator, Credentials=fake_creds_ctor
)

def fake_auth_ctor(creds):
assert creds is fake_creds
return SimpleNamespace(
add_auth=lambda request: request.headers.__setitem__("AUTH", "TOKEN")
)

# Mirror the auth class that register.py imports (botocore.auth.HmacV1Auth).
botocore.auth = SimpleNamespace(HmacV1Auth=fake_auth_ctor)

monkeypatch.setitem(sys.modules, "botocore", botocore)


@pytest.fixture
def replace_url(monkeypatch):
def with_port(port):
old_native_download_file_init = NativeDownloadFile.__init__

def new_init(self, **kwargs):
assert kwargs["url"] == "https://bucket.s3.amazonaws.com/keypart1/keypart2/file.txt"
kwargs["url"] = f"http://localhost:{port}/file.txt"
return old_native_download_file_init(self, **kwargs)

monkeypatch.setattr(NativeDownloadFile, "__init__", new_init)

return with_port


def test_download_s3(rule_runner: RuleRunner, monkeypatch_botocore, replace_url) -> None:
class S3HTTPHandler(BaseHTTPRequestHandler):
response_text = b"Hello, client!"

def do_HEAD(self):
self.send_headers()

def do_GET(self):
self.send_headers()
self.wfile.write(self.response_text)

def send_headers(self):
assert self.headers["AUTH"] == "TOKEN"
self.send_response(200)
self.send_header("Content-Type", "binary/octet-stream")
self.send_header("Content-Length", f"{len(self.response_text)}")
self.end_headers()

with http_server(S3HTTPHandler) as port:
replace_url(port)
snapshot = rule_runner.request(
Snapshot,
[DownloadS3SchemeURL("s3://bucket/keypart1/keypart2/file.txt", DOWNLOADS_FILE_DIGEST)],
)
assert snapshot.files == ("file.txt",)
assert snapshot.digest == DOWNLOADS_EXPECTED_DIRECTORY_DIGEST
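
For reference, a minimal standalone sketch of why the `sys.modules` monkeypatch in the fixture above is enough to satisfy register.py's deferred `from botocore import ...` statements; the module object here is a stand-in, just as in the fixture:

import sys
from types import SimpleNamespace

# Placing any object under the module name in sys.modules means the import machinery
# finds it there, and `from botocore import session` becomes an attribute lookup on it.
fake = SimpleNamespace(session=SimpleNamespace(get_session=lambda: "fake-session"))
sys.modules["botocore"] = fake

from botocore import session  # resolves to fake.session

assert session.get_session() == "fake-session"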
146 changes: 146 additions & 0 deletions src/python/pants/backend/url_handlers/s3/register.py
@@ -0,0 +1,146 @@
# Copyright 2022 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).
import logging
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any
from urllib.parse import urlparse

from pants.engine.download_file import URLDownloadHandler
from pants.engine.fs import Digest, NativeDownloadFile
from pants.engine.internals.selectors import Get
from pants.engine.rules import collect_rules, rule
from pants.engine.unions import UnionRule
from pants.util.strutil import softwrap

CONTENT_TYPE = "binary/octet-stream"


logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class AWSCredentials:
creds: Any


@rule
async def access_aws_credentials() -> AWSCredentials:
try:
from botocore import credentials, session
except ImportError:
logger.warning(
softwrap(
"""
In order to resolve s3:// URLs, Pants must load AWS credentials. To do so, `botocore`
must be importable in Pants' environment.

To do that, add an entry to `[GLOBAL].plugins` for a pip-resolvable package to download from PyPI
(e.g. `botocore == 1.29.39`). Note that the `botocore` package from PyPI is >70MB at the time of
writing, so an alternate package providing the `botocore` modules may be advisable.
"""
)
)
raise

session = session.get_session()
creds = credentials.create_credential_resolver(session).load_credentials()

return AWSCredentials(creds)


# NB: The URL is expected to be in path-style
# See https://docs.aws.amazon.com/AmazonS3/latest/userguide/VirtualHosting.html
def _get_aws_auth_headers(url: str, aws_credentials: AWSCredentials):
from botocore import auth # pants: no-infer-dep

request = SimpleNamespace(
url=url,
headers={},
method="GET",
auth_path=None,
)
auth.HmacV1Auth(aws_credentials.creds).add_auth(request)
return request.headers


class DownloadS3SchemeURL(URLDownloadHandler):
match_scheme = "s3"


@rule
async def download_file_from_s3_scheme(
request: DownloadS3SchemeURL, aws_credentials: AWSCredentials
) -> Digest:
parsed_url = urlparse(request.url)
bucket = parsed_url.netloc
key = parsed_url.path
http_url = f"https://s3.amazonaws.com/{bucket}{key}"
headers = _get_aws_auth_headers(http_url, aws_credentials)

digest = await Get(
Digest,
NativeDownloadFile(
url=http_url,
expected_digest=request.expected_digest,
auth_headers=headers,
),
)
return digest


class DownloadS3AuthorityVirtualHostedStyleURL(URLDownloadHandler):
match_authority = "*.s3*amazonaws.com"


@rule
async def download_file_from_virtual_hosted_s3_authority(
request: DownloadS3AuthorityVirtualHostedStyleURL, aws_credentials: AWSCredentials
) -> Digest:
parsed_url = urlparse(request.url)
bucket = parsed_url.netloc.split(".", 1)[0]
# NB: Turn this into a path-style request
path_style_url = f"https://s3.amazonaws.com/{bucket}{parsed_url.path}"
if parsed_url.query:
path_style_url += f"?{parsed_url.query}"
headers = _get_aws_auth_headers(path_style_url, aws_credentials)

digest = await Get(
Digest,
NativeDownloadFile(
url=request.url,
expected_digest=request.expected_digest,
auth_headers=headers,
),
)
return digest


class DownloadS3AuthorityPathStyleURL(URLDownloadHandler):
match_authority = "s3.*amazonaws.com"


@rule
async def download_file_from_path_s3_authority(
request: DownloadS3AuthorityPathStyleURL, aws_credentials: AWSCredentials
) -> Digest:
headers = _get_aws_auth_headers(request.url, aws_credentials)
digest = await Get(
Digest,
NativeDownloadFile(
url=request.url,
expected_digest=request.expected_digest,
auth_headers=headers,
),
)
return digest


def rules():
return [
UnionRule(URLDownloadHandler, DownloadS3SchemeURL),
UnionRule(URLDownloadHandler, DownloadS3AuthorityVirtualHostedStyleURL),
UnionRule(URLDownloadHandler, DownloadS3AuthorityPathStyleURL),
*collect_rules(),
]
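
For reference, a minimal standalone sketch of the virtual-hosted-style to path-style rewrite that `download_file_from_virtual_hosted_s3_authority` performs before computing auth headers; the bucket and key are hypothetical, mirroring the URL used in the integration test:

from urllib.parse import urlparse

# Hypothetical virtual-hosted-style URL.
url = "https://bucket.s3.amazonaws.com/keypart1/keypart2/file.txt"
parsed = urlparse(url)
# The bucket is the first label of the authority; the path is the object key.
bucket = parsed.netloc.split(".", 1)[0]
path_style = f"https://s3.amazonaws.com/{bucket}{parsed.path}"
assert path_style == "https://s3.amazonaws.com/bucket/keypart1/keypart2/file.txt"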
103 changes: 103 additions & 0 deletions src/python/pants/engine/download_file.py
@@ -0,0 +1,103 @@
# Copyright 2022 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).

from dataclasses import dataclass
from fnmatch import fnmatch
from typing import ClassVar, Optional
from urllib.parse import urlparse

from pants.engine.fs import Digest, DownloadFile, NativeDownloadFile
from pants.engine.internals.native_engine import FileDigest
from pants.engine.internals.selectors import Get
from pants.engine.rules import collect_rules, rule
from pants.engine.unions import UnionMembership, union
from pants.util.strutil import bullet_list, softwrap


@union
@dataclass(frozen=True)
class URLDownloadHandler:
"""Union base for custom URL handler.

To register a custom URL handler:
- Subclass this class and declare one or both of the ClassVars.
- Declare a rule that takes in your class type and returns a `Digest`.
- Register your union member in your `rules()`: `UnionRule(URLDownloadHandler, YourClass)`.

Example:

class S3DownloadHandler(URLDownloadHandler):
match_scheme = "s3"

@rule
async def download_s3_file(request: S3DownloadHandler) -> Digest:
# Lookup auth tokens, etc...
# Ideally, download the file using `NativeDownloadFile()`
return digest

def rules():
return [
*collect_rules(),
UnionRule(URLDownloadHandler, S3DownloadHandler),
]
"""

match_scheme: ClassVar[Optional[str]] = None
"""The scheme to match (e.g. 'ftp' or 's3') or `None` to match all schemes.

The scheme is matched using `fnmatch`; see https://docs.python.org/3/library/fnmatch.html for more
information.
"""

match_authority: ClassVar[Optional[str]] = None
Member Author commented:
One thing to consider is allowing wildcards. I suspect at <company> we'll want to use https://*.s3.amazonaws.com. I can upstream that as well.

"""The authority to match (e.g. 'pantsbuild.org' or 's3.amazonaws.com') or `None` to match all authorities.

The authority is matched using `fnmatch`; see https://docs.python.org/3/library/fnmatch.html for more
information.

Note that the authority matches userinfo (e.g. 'me@pantsbuild.org' or 'me:password@pantsbuild.org')
as well as port (e.g. 'pantsbuild.org:80').
"""

url: str
expected_digest: FileDigest


@rule
async def download_file(
request: DownloadFile,
union_membership: UnionMembership,
) -> Digest:
parsed_url = urlparse(request.url)
handlers = union_membership.get(URLDownloadHandler)
matched_handlers = []
for handler in handlers:
matches_scheme = handler.match_scheme is None or fnmatch(
parsed_url.scheme, handler.match_scheme
)
matches_authority = handler.match_authority is None or fnmatch(
parsed_url.netloc, handler.match_authority
)
if matches_scheme and matches_authority:
matched_handlers.append(handler)

if len(matched_handlers) > 1:
raise Exception(
softwrap(
f"""
Too many registered URL handlers match the URL '{request.url}'.

Matched handlers:
{bullet_list(map(str, matched_handlers))}
"""
)
)
if len(matched_handlers) == 1:
handler = matched_handlers[0]
return await Get(Digest, URLDownloadHandler, handler(request.url, request.expected_digest))

return await Get(Digest, NativeDownloadFile(request.url, request.expected_digest))


def rules():
return collect_rules()
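
For reference, a minimal standalone sketch of the `fnmatch`-based matching in `download_file` above, exercising the authority patterns registered by the s3 backend against a hypothetical regional, virtual-hosted-style URL:

from fnmatch import fnmatch
from urllib.parse import urlparse

# Hypothetical regional, virtual-hosted-style URL.
parsed = urlparse("https://bucket.s3.us-east-1.amazonaws.com/keypart1/file.txt")

# Matches DownloadS3AuthorityVirtualHostedStyleURL's pattern:
assert fnmatch(parsed.netloc, "*.s3*amazonaws.com")
# Does not match DownloadS3AuthorityPathStyleURL's pattern:
assert not fnmatch(parsed.netloc, "s3.*amazonaws.com")
# An s3:// URL is matched by scheme instead (DownloadS3SchemeURL):
assert fnmatch(urlparse("s3://bucket/keypart1/file.txt").scheme, "s3")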