Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow plugins to add custom schema/authority URL handler rules #17898

Merged
merged 7 commits into from
Jan 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ show_traceback = true

[[tool.mypy.overrides]]
module = [
"botocore",
"bs4",
"chevron",
"colors",
Expand Down
4 changes: 4 additions & 0 deletions src/python/pants/backend/url_handlers/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright 2022 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).

python_sources()
Empty file.
5 changes: 5 additions & 0 deletions src/python/pants/backend/url_handlers/s3/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright 2022 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).

python_sources()
python_tests(name="tests")
Empty file.
190 changes: 190 additions & 0 deletions src/python/pants/backend/url_handlers/s3/integration_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
# Copyright 2022 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).

import sys
from http.server import BaseHTTPRequestHandler
from types import SimpleNamespace

import pytest

from pants.backend.url_handlers.s3.register import (
DownloadS3AuthorityPathStyleURL,
DownloadS3AuthorityVirtualHostedStyleURL,
DownloadS3SchemeURL,
)
from pants.backend.url_handlers.s3.register import rules as s3_rules
from pants.engine.fs import Digest, FileDigest, NativeDownloadFile, Snapshot
from pants.engine.rules import QueryRule
from pants.testutil.rule_runner import RuleRunner
from pants.util.contextutil import http_server

# FileDigest of the 14-byte body b"Hello, client!" served by the test HTTP handler below.
DOWNLOADS_FILE_DIGEST = FileDigest(
    "8fcbc50cda241aee7238c71e87c27804e7abc60675974eaf6567aa16366bc105", 14
)
# Digest of the resulting snapshot directory containing the single downloaded "file.txt".
DOWNLOADS_EXPECTED_DIRECTORY_DIGEST = Digest(
    "4c9cf91fcd7ba1abbf7f9a0a1c8175556a82bee6a398e34db3284525ac24a3ad", 84
)


@pytest.fixture
def rule_runner() -> RuleRunner:
    """A RuleRunner with the S3 URL-handler rules and one Snapshot query per handler type."""
    queries = [
        QueryRule(Snapshot, [DownloadS3SchemeURL]),
        QueryRule(Snapshot, [DownloadS3AuthorityVirtualHostedStyleURL]),
        QueryRule(Snapshot, [DownloadS3AuthorityPathStyleURL]),
    ]
    return RuleRunner(rules=[*s3_rules(), *queries], isolated_local_store=True)


@pytest.fixture
def monkeypatch_botocore(monkeypatch):
    """Install a fake `botocore` module into `sys.modules`.

    Returns a callable taking the URL the S3 signer is expected to sign. The fakes
    verify the production wiring (get_session -> create_credential_resolver ->
    load_credentials -> HmacV1Auth) and stamp an `AUTH: TOKEN` header on the signed
    request so the test HTTP server can check that auth headers are forwarded.
    """

    def do_patching(expected_url):
        botocore = SimpleNamespace()
        botocore.exceptions = SimpleNamespace(NoCredentialsError=Exception)
        fake_session = object()
        fake_creds = SimpleNamespace(access_key="ACCESS", secret_key="SECRET")
        botocore.session = SimpleNamespace(get_session=lambda: fake_session)

        def fake_resolver_creator(session):
            assert session is fake_session
            return SimpleNamespace(load_credentials=lambda: fake_creds)

        def fake_creds_ctor(access_key, secret_key):
            assert access_key == fake_creds.access_key
            assert secret_key == fake_creds.secret_key
            return fake_creds

        botocore.credentials = SimpleNamespace(
            create_credential_resolver=fake_resolver_creator, Credentials=fake_creds_ctor
        )

        def fake_auth_ctor(creds):
            assert creds is fake_creds

            def add_auth(request):
                # BUG FIX: the original wrote `request.url == expected_url`, a bare
                # comparison whose result was discarded — a wrong signing URL would
                # never fail the test. Assert it instead.
                assert request.url == expected_url
                request.headers["AUTH"] = "TOKEN"

            return SimpleNamespace(add_auth=add_auth)

        botocore.auth = SimpleNamespace(HmacV1Auth=fake_auth_ctor)

        monkeypatch.setitem(sys.modules, "botocore", botocore)

    return do_patching


@pytest.fixture
def replace_url(monkeypatch):
    """Redirect `NativeDownloadFile` to the local test HTTP server.

    Returns a callable `(expected_url, port)` that patches
    `NativeDownloadFile.__init__` to assert the URL the engine was asked to
    download and then rewrite it to `http://localhost:<port>/file.txt`.
    """

    def with_port(expected_url, port):
        original_init = NativeDownloadFile.__init__

        def patched_init(self, **kwargs):
            assert kwargs["url"] == expected_url
            kwargs["url"] = f"http://localhost:{port}/file.txt"
            return original_init(self, **kwargs)

        monkeypatch.setattr(NativeDownloadFile, "__init__", patched_init)

    return with_port


@pytest.mark.parametrize(
    # request_url: the URL handed to the handler under test.
    # expected_auth_url: the path-style URL the fake signer must be asked to sign.
    # expected_native_url: the virtual-hosted-style URL the engine must download.
    "request_url, expected_auth_url, expected_native_url, req_type",
    [
        # s3:// scheme
        (
            "s3://bucket/keypart1/keypart2/file.txt",
            "https://bucket.s3.amazonaws.com/keypart1/keypart2/file.txt",
            "https://bucket.s3.amazonaws.com/keypart1/keypart2/file.txt",
            DownloadS3SchemeURL,
        ),
        # Path-style
        (
            "https://s3.amazonaws.com/bucket/keypart1/keypart2/file.txt",
            "https://s3.amazonaws.com/bucket/keypart1/keypart2/file.txt",
            "https://bucket.s3.amazonaws.com/keypart1/keypart2/file.txt",
            DownloadS3AuthorityPathStyleURL,
        ),
        (
            "https://s3.amazonaws.com/bucket/keypart1/keypart2/file.txt?versionId=ABC123",
            "https://s3.amazonaws.com/bucket/keypart1/keypart2/file.txt?versionId=ABC123",
            "https://bucket.s3.amazonaws.com/keypart1/keypart2/file.txt?versionId=ABC123",
            DownloadS3AuthorityPathStyleURL,
        ),
        (
            "https://s3.us-west-2.amazonaws.com/bucket/keypart1/keypart2/file.txt",
            "https://s3.us-west-2.amazonaws.com/bucket/keypart1/keypart2/file.txt",
            "https://bucket.s3.us-west-2.amazonaws.com/keypart1/keypart2/file.txt",
            DownloadS3AuthorityPathStyleURL,
        ),
        (
            "https://s3.us-west-2.amazonaws.com/bucket/keypart1/keypart2/file.txt?versionId=ABC123",
            "https://s3.us-west-2.amazonaws.com/bucket/keypart1/keypart2/file.txt?versionId=ABC123",
            "https://bucket.s3.us-west-2.amazonaws.com/keypart1/keypart2/file.txt?versionId=ABC123",
            DownloadS3AuthorityPathStyleURL,
        ),
        # Virtual-hosted-style
        (
            "https://bucket.s3.amazonaws.com/keypart1/keypart2/file.txt",
            "https://s3.amazonaws.com/bucket/keypart1/keypart2/file.txt",
            "https://bucket.s3.amazonaws.com/keypart1/keypart2/file.txt",
            DownloadS3AuthorityVirtualHostedStyleURL,
        ),
        # BUG FIX: the original listed this exact case twice; the duplicate is removed.
        (
            "https://bucket.s3.amazonaws.com/keypart1/keypart2/file.txt?versionId=ABC123",
            "https://s3.amazonaws.com/bucket/keypart1/keypart2/file.txt?versionId=ABC123",
            "https://bucket.s3.amazonaws.com/keypart1/keypart2/file.txt?versionId=ABC123",
            DownloadS3AuthorityVirtualHostedStyleURL,
        ),
        (
            "https://bucket.s3.us-west-2.amazonaws.com/keypart1/keypart2/file.txt",
            "https://s3.us-west-2.amazonaws.com/bucket/keypart1/keypart2/file.txt",
            "https://bucket.s3.us-west-2.amazonaws.com/keypart1/keypart2/file.txt",
            DownloadS3AuthorityVirtualHostedStyleURL,
        ),
    ],
)
def test_download_s3(
    rule_runner: RuleRunner,
    monkeypatch_botocore,
    request_url: str,
    expected_auth_url: str,
    expected_native_url: str,
    req_type: type,
    replace_url,
) -> None:
    """End-to-end: an S3 URL is signed (fake botocore) and fetched from a local server."""

    class S3HTTPHandler(BaseHTTPRequestHandler):
        response_text = b"Hello, client!"

        def do_HEAD(self):
            self.send_headers()

        def do_GET(self):
            self.send_headers()
            self.wfile.write(self.response_text)

        def send_headers(self):
            # Fails the request if the signer's auth header wasn't forwarded.
            assert self.headers["AUTH"] == "TOKEN"
            self.send_response(200)
            self.send_header("Content-Type", "binary/octet-stream")
            self.send_header("Content-Length", f"{len(self.response_text)}")
            self.end_headers()

    monkeypatch_botocore(expected_auth_url)
    with http_server(S3HTTPHandler) as port:
        replace_url(expected_native_url, port)
        snapshot = rule_runner.request(
            Snapshot,
            [req_type(request_url, DOWNLOADS_FILE_DIGEST)],
        )
        assert snapshot.files == ("file.txt",)
        assert snapshot.digest == DOWNLOADS_EXPECTED_DIRECTORY_DIGEST
175 changes: 175 additions & 0 deletions src/python/pants/backend/url_handlers/s3/register.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
# Copyright 2022 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).
import logging
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any
from urllib.parse import urlsplit

from pants.engine.download_file import URLDownloadHandler
from pants.engine.fs import Digest, NativeDownloadFile
from pants.engine.internals.native_engine import FileDigest
from pants.engine.internals.selectors import Get
from pants.engine.rules import collect_rules, rule
from pants.engine.unions import UnionRule
from pants.util.strutil import softwrap

CONTENT_TYPE = "binary/octet-stream"


logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class AWSCredentials:
    """Immutable wrapper for the credentials object resolved by botocore.

    Typed `Any` because `botocore` is an optional, lazily-imported dependency
    (see `access_aws_credentials`).
    """

    # Result of botocore's credential resolver; presumably may be None when no
    # credentials are found — TODO confirm against botocore's load_credentials().
    creds: Any


@rule
async def access_aws_credentials() -> AWSCredentials:
    """Resolve AWS credentials using botocore's default credential chain.

    Raises (re-raises) ImportError, after logging installation guidance, when
    `botocore` is not importable in Pants' environment.
    """
    try:
        from botocore import credentials, session
    except ImportError:
        logger.warning(
            softwrap(
                """
                In order to resolve s3:// URLs, Pants must load AWS credentials. To do so, `botocore`
                must be importable in Pants' environment.

                To do that add an entry to `[GLOBAL].plugins` of a pip-resolvable package to download from PyPI.
                (E.g. `botocore == 1.29.39`). Note that the `botocore` package from PyPI at the time
                of writing is >70MB, so an alternate package providing the `botocore` modules may be
                advisable.
                """
            )
        )
        raise

    boto_session = session.get_session()
    resolver = credentials.create_credential_resolver(boto_session)
    return AWSCredentials(resolver.load_credentials())


@dataclass(frozen=True)
class S3DownloadFile:
    """Parameters for downloading a single S3 object.

    Consumed by `download_from_s3`, which signs a path-style URL for auth and
    then downloads via the virtual-hosted-style URL.
    """

    # AWS region (e.g. "us-west-2"); empty string means the region-less legacy endpoint.
    region: str
    bucket: str
    # Object key; callers strip the leading "/" (e.g. via `split.path[1:]`).
    key: str
    # Raw query string without the leading "?" (e.g. "versionId=ABC123"); may be empty.
    query: str
    expected_digest: FileDigest


@rule
async def download_from_s3(request: S3DownloadFile, aws_credentials: AWSCredentials) -> Digest:
    """Sign an S3 request with the resolved credentials and download the object.

    The request is signed against the path-style URL, but the actual download uses
    the virtual-hosted-style URL with the resulting auth headers attached.
    """
    from botocore import auth, exceptions  # pants: no-infer-dep

    # NOTE(review): request.key is expected NOT to start with "/" (callers strip it
    # via `split.path[1:]`) — consider asserting that here.

    # NB: The URL for auth is expected to be in path-style
    path_style_url = "https://s3"
    if request.region:
        path_style_url += f".{request.region}"
    path_style_url += f".amazonaws.com/{request.bucket}/{request.key}"
    if request.query:
        path_style_url += f"?{request.query}"

    # Minimal stand-in for an AWSRequest: HmacV1Auth only needs these attributes.
    http_request = SimpleNamespace(
        url=path_style_url,
        headers={},
        method="GET",
        auth_path=None,
    )
    # NB: The added Auth header doesn't need to be valid when accessing a public bucket. When
    # hand-testing, you MUST test against a private bucket to ensure it works for private buckets too.
    signer = auth.HmacV1Auth(aws_credentials.creds)
    try:
        signer.add_auth(http_request)
    except exceptions.NoCredentialsError:
        pass  # The user can still access public S3 buckets without credentials

    virtual_hosted_url = f"https://{request.bucket}.s3"
    if request.region:
        virtual_hosted_url += f".{request.region}"
    virtual_hosted_url += f".amazonaws.com/{request.key}"
    if request.query:
        virtual_hosted_url += f"?{request.query}"

    return await Get(
        Digest,
        NativeDownloadFile(
            url=virtual_hosted_url,
            expected_digest=request.expected_digest,
            auth_headers=http_request.headers,
        ),
    )


class DownloadS3SchemeURL(URLDownloadHandler):
    """Matches URLs of the form `s3://<bucket>/<key>`."""

    match_scheme = "s3"


@rule
async def download_file_from_s3_scheme(request: DownloadS3SchemeURL) -> Digest:
    """Translate an `s3://bucket/key` URL into an `S3DownloadFile` request."""
    parts = urlsplit(request.url)
    # netloc is the bucket; the key is the path minus its leading "/".
    # NOTE(review): the region and any query string are not derivable from / are
    # dropped for s3:// URLs here — confirm that's intended for versioned objects.
    return await Get(
        Digest,
        S3DownloadFile(
            region="",
            bucket=parts.netloc,
            key=parts.path[1:],
            query="",
            expected_digest=request.expected_digest,
        ),
    )


class DownloadS3AuthorityVirtualHostedStyleURL(URLDownloadHandler):
    """Matches virtual-hosted-style URLs, e.g. `https://<bucket>.s3.<region>.amazonaws.com/<key>`."""

    match_authority = "*.s3*amazonaws.com"


@rule
async def download_file_from_virtual_hosted_s3_authority(
    request: DownloadS3AuthorityVirtualHostedStyleURL, aws_credentials: AWSCredentials
) -> Digest:
    """Parse a virtual-hosted-style S3 URL (`https://<bucket>.s3[.<region>].amazonaws.com/...`)."""
    parts = urlsplit(request.url)
    # The first netloc label is the bucket; the remainder is the S3 endpoint.
    bucket, aws_netloc = parts.netloc.split(".", 1)
    # "s3.<region>.amazonaws.com" has 3 dots; "s3.amazonaws.com" has 2 (no region).
    if aws_netloc.count(".") == 3:
        region = aws_netloc.split(".")[1]
    else:
        region = ""
    return await Get(
        Digest,
        S3DownloadFile(
            region=region,
            bucket=bucket,
            key=parts.path[1:],
            query=parts.query,
            expected_digest=request.expected_digest,
        ),
    )


class DownloadS3AuthorityPathStyleURL(URLDownloadHandler):
    """Matches path-style URLs, e.g. `https://s3.<region>.amazonaws.com/<bucket>/<key>`."""

    match_authority = "s3.*amazonaws.com"


@rule
async def download_file_from_path_s3_authority(
    request: DownloadS3AuthorityPathStyleURL, aws_credentials: AWSCredentials
) -> Digest:
    """Parse a path-style S3 URL (`https://s3[.<region>].amazonaws.com/<bucket>/<key>`)."""
    parts = urlsplit(request.url)
    # Path is "/<bucket>/<key...>": drop the empty leading segment.
    _, bucket, key = parts.path.split("/", 2)
    # "s3.<region>.amazonaws.com" has 3 dots; "s3.amazonaws.com" has 2 (no region).
    if parts.netloc.count(".") == 3:
        region = parts.netloc.split(".")[1]
    else:
        region = ""
    return await Get(
        Digest,
        S3DownloadFile(
            region=region,
            bucket=bucket,
            key=key,
            query=parts.query,
            expected_digest=request.expected_digest,
        ),
    )


def rules():
    """Register the three S3 URL handlers plus this module's @rule functions."""
    handler_types = (
        DownloadS3SchemeURL,
        DownloadS3AuthorityVirtualHostedStyleURL,
        DownloadS3AuthorityPathStyleURL,
    )
    return [
        *(UnionRule(URLDownloadHandler, handler) for handler in handler_types),
        *collect_rules(),
    ]
Loading