Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow plugins to add custom schema/authority URL handler rules #17898

Merged
merged 7 commits into from
Jan 4, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/python/pants/backend/url_handlers/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright 2022 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).

python_sources()
Empty file.
4 changes: 4 additions & 0 deletions src/python/pants/backend/url_handlers/s3/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright 2022 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).

python_sources()
Empty file.
94 changes: 94 additions & 0 deletions src/python/pants/backend/url_handlers/s3/register.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright 2022 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).
import logging
from dataclasses import dataclass
from types import SimpleNamespace
from urllib.parse import urlparse

from pants.engine.download_file import URLDownloadHandler
from pants.engine.fs import Digest, NativeDownloadFile
from pants.engine.internals.selectors import Get
from pants.engine.rules import collect_rules, rule
from pants.engine.unions import UnionRule
from pants.util.strutil import softwrap

CONTENT_TYPE = "binary/octet-stream"


logger = logging.getLogger(__name__)


class DownloadS3URLHandler(URLDownloadHandler):
    """Handles s3:// URLs by signing the request with AWS credentials.

    NB: The union base declares the ClassVar `match_scheme` (singular "match"). The previous
    `matches_scheme` name was silently ignored, which left both match ClassVars as `None` —
    a double wildcard that routed *every* URL to this handler.
    """

    match_scheme = "s3"


@dataclass(frozen=True)
class AWSCredentials:
    """AWS credentials as resolved by botocore's credential-provider chain.

    Held as plain strings so the value is hashable/memoizable by the engine.
    """

    # The AWS access key ID (e.g. from AWS_ACCESS_KEY_ID or ~/.aws/credentials).
    access_key_id: str
    # The secret access key paired with `access_key_id`.
    secret_access_key: str


@rule
async def access_aws_credentials() -> AWSCredentials:
    """Load AWS credentials via botocore's standard credential-provider chain.

    `botocore` is not a Pants dependency; it must be supplied by the user through
    `[GLOBAL].plugins`. A missing import is surfaced with guidance before re-raising.
    """
    try:
        import botocore.credentials  # pants: no-infer-dep
        import botocore.session  # pants: no-infer-dep
    except ImportError:
        logger.warning(
            softwrap(
                """
                In order to resolve s3:// URLs, Pants must load AWS credentials. To do so, `botocore`
                must be importable in Pants' environment.

                To do that add an entry to `[GLOBAL].plugins` of a pip-resolvable package to download from PyPI.
                (E.g. `botocore == 1.29.39`). Note that the `botocore` package from PyPI at the time
                of writing is >70MB, so an alternate package providing the `botocore` modules may be
                advisable.
                """
            )
        )
        raise

    resolver = botocore.credentials.create_credential_resolver(botocore.session.get_session())
    loaded = resolver.load_credentials()
    return AWSCredentials(access_key_id=loaded.access_key, secret_access_key=loaded.secret_key)


@rule
async def download_s3_file(
    request: DownloadS3URLHandler, aws_credentials: AWSCredentials
) -> Digest:
    """Download an s3:// URL by signing the request and delegating to the native downloader.

    The s3:// URL is rewritten to the virtual-hosted-style HTTPS form
    (`https://<bucket>.s3.amazonaws.com/<key>`) and authenticated via headers.
    """
    import botocore.auth  # pants: no-infer-dep
    import botocore.credentials  # pants: no-infer-dep

    boto_creds = botocore.credentials.Credentials(
        aws_credentials.access_key_id, aws_credentials.secret_access_key
    )
    # NOTE(review): SigV3 is an unusual signing scheme for S3 (SigV4 is the current standard) —
    # confirm this actually authenticates against real buckets.
    auth = botocore.auth.SigV3Auth(boto_creds)
    # Only the `.headers` attribute is consumed below, so a SimpleNamespace stands in for a
    # full botocore request object.
    headers_container = SimpleNamespace(headers={})
    auth.add_auth(headers_container)

    parsed_url = urlparse(request.url)
    bucket = parsed_url.netloc
    # `path` keeps its leading slash, which the f-string URL below relies on.
    key = parsed_url.path

    digest = await Get(
        Digest,
        NativeDownloadFile(
            url=f"https://{bucket}.s3.amazonaws.com{key}",
            expected_digest=request.expected_digest,
            auth_headers=headers_container.headers,
        ),
    )
    return digest


def rules():
    """Register the s3:// handler as a `URLDownloadHandler` union member plus this module's rules."""
    registrations = [UnionRule(URLDownloadHandler, DownloadS3URLHandler)]
    registrations.extend(collect_rules())
    return registrations
85 changes: 85 additions & 0 deletions src/python/pants/engine/download_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright 2022 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).

from dataclasses import dataclass
from typing import ClassVar, Optional
from urllib.parse import urlparse

from pants.engine.fs import Digest, DownloadFile, NativeDownloadFile
from pants.engine.internals.native_engine import FileDigest
from pants.engine.internals.selectors import Get
from pants.engine.rules import collect_rules, rule
from pants.engine.unions import UnionMembership, union


@union
@dataclass(frozen=True)
class URLDownloadHandler:
    """Union base for custom URL handler.

    To register a custom URL handler:
    - Subclass this class and declare one or both of the ClassVars.
    - Declare a rule that takes in your class type and returns a `Digest`.
    - Register your union member in your `rules()`: `UnionRule(URLDownloadHandler, YourClass)`.

    Example:

        class S3DownloadHandler(URLDownloadHandler):
            match_scheme = "s3"

        @rule
        async def download_s3_file(request: S3DownloadHandler) -> Digest:
            # Lookup auth tokens, etc...
            # Ideally, download the file using `NativeDownloadFile()`
            return digest

        def rules():
            return [
                *collect_rules(),
                UnionRule(URLDownloadHandler, S3DownloadHandler),
            ]
    """

    match_scheme: ClassVar[Optional[str]] = None
    """The scheme to match (e.g. 'ftp' or 's3') or `None` to match all schemes.

    Note that 'http' and 'https' are two different schemes. In order to match either, you'll need to
    register both.
    """

    match_authority: ClassVar[Optional[str]] = None
    """The authority to match (e.g. 'pantsbuild.org' or 's3.amazonaws.com') or `None` to match all
    authorities.

    Note that the authority matches userinfo (e.g. '[email protected]' or 'me:[email protected]')
    as well as port (e.g. 'pantsbuild.org:80').
    """

    url: str
    expected_digest: FileDigest


@rule
async def download_file(
    request: DownloadFile,
    union_membership: UnionMembership,
) -> Digest:
    """Download a file, dispatching to the first registered handler matching the URL.

    Falls back to the native downloader when no `URLDownloadHandler` union member matches.
    """
    parsed_url = urlparse(request.url)
    handlers = union_membership.get(URLDownloadHandler)
    for handler in handlers:
        matches_scheme = handler.match_scheme is None or handler.match_scheme == parsed_url.scheme
        matches_authority = (
            handler.match_authority is None or handler.match_authority == parsed_url.netloc
        )
        # Both criteria must hold. A `None` ClassVar is a wildcard (always True), so combining
        # with `or` would incorrectly route URLs to handlers whose declared scheme or authority
        # does not match (e.g. a scheme-only handler would match every URL).
        if matches_scheme and matches_authority:
            digest = await Get(
                Digest, URLDownloadHandler, handler(request.url, request.expected_digest)
            )
            break
    else:
        digest = await Get(Digest, NativeDownloadFile(request.url, request.expected_digest))

    return digest


def rules():
    # Expose this module's rules (namely `download_file`) for engine registration.
    return collect_rules()
33 changes: 31 additions & 2 deletions src/python/pants/engine/fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, Iterable, Optional, Tuple, Union
from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Tuple, Union

# Note: several of these types are re-exported as the public API of `engine/fs.py`.
from pants.base.glob_match_error_behavior import GlobMatchErrorBehavior as GlobMatchErrorBehavior
Expand All @@ -23,6 +23,7 @@
from pants.engine.internals.native_engine import RemovePrefix as RemovePrefix
from pants.engine.internals.native_engine import Snapshot as Snapshot
from pants.engine.rules import QueryRule
from pants.util.frozendict import FrozenDict
from pants.util.meta import frozen_after_init

if TYPE_CHECKING:
Expand Down Expand Up @@ -248,6 +249,34 @@ class DownloadFile:
expected_digest: FileDigest


@frozen_after_init
@dataclass(unsafe_hash=True)
class NativeDownloadFile:
    """Retrieve the contents of a file via an HTTP GET request or directly for local file:// URLs.

    This request is handled directly by the native engine without any additional coercion by plugins,
    and therefore should only be used in cases where the URL is known to be publicly accessible.
    Otherwise, callers should use `DownloadFile`.

    The auth_headers are part of this node's cache key for memoization (changing a header invalidates
    prior results) but are not part of the underlying cache key for the local/remote cache (changing
    a header won't re-download a file if the file was previously downloaded).
    """

    url: str
    expected_digest: FileDigest
    # NB: This mapping can be of any arbitrary headers, but should be limited to those required for
    # authorization.
    auth_headers: FrozenDict[str, str]

    def __init__(
        self, url: str, expected_digest: FileDigest, auth_headers: Mapping[str, str] | None = None
    ) -> None:
        self.url = url
        self.expected_digest = expected_digest
        # Normalize to an immutable FrozenDict so the dataclass hash (and thus memoization) is stable.
        self.auth_headers = FrozenDict(auth_headers or {})


@dataclass(frozen=True)
class Workspace(SideEffecting):
"""A handle for operations that mutate the local filesystem."""
Expand Down Expand Up @@ -300,7 +329,7 @@ def rules():
QueryRule(Digest, (PathGlobs,)),
QueryRule(Digest, (AddPrefix,)),
QueryRule(Digest, (RemovePrefix,)),
QueryRule(Digest, (DownloadFile,)),
QueryRule(Digest, (NativeDownloadFile,)),
QueryRule(Digest, (MergeDigests,)),
QueryRule(Digest, (DigestSubset,)),
QueryRule(DigestContents, (Digest,)),
Expand Down
4 changes: 2 additions & 2 deletions src/python/pants/engine/internals/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
DigestEntries,
DigestSubset,
Directory,
DownloadFile,
FileContent,
FileDigest,
FileEntry,
NativeDownloadFile,
PathGlobs,
PathGlobsAndRoot,
Paths,
Expand Down Expand Up @@ -157,7 +157,7 @@ def __init__(
path_globs=PathGlobs,
create_digest=CreateDigest,
digest_subset=DigestSubset,
download_file=DownloadFile,
native_download_file=NativeDownloadFile,
platform=Platform,
process=Process,
process_result=FallibleProcessResult,
Expand Down
3 changes: 2 additions & 1 deletion src/python/pants/init/engine_initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from pants.build_graph.build_configuration import BuildConfiguration
from pants.core.util_rules import environments, system_binaries
from pants.core.util_rules.environments import determine_bootstrap_environment
from pants.engine import desktop, fs, process
from pants.engine import desktop, download_file, fs, process
from pants.engine.console import Console
from pants.engine.environment import EnvironmentName
from pants.engine.fs import PathGlobs, Snapshot, Workspace
Expand Down Expand Up @@ -272,6 +272,7 @@ def build_root_singleton() -> BuildRoot:
*fs.rules(),
*dep_rules.rules(),
*desktop.rules(),
*download_file.rules(),
*git_rules(),
*graph.rules(),
*specs_rules.rules(),
Expand Down
26 changes: 24 additions & 2 deletions src/rust/engine/src/downloads.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright 2021 Pants project contributors (see CONTRIBUTORS.md).
// Licensed under the Apache License, Version 2.0 (see LICENSE).

use std::collections::BTreeMap;
use std::io::{self, Write};
use std::pin::Pin;
use std::sync::Arc;
Expand All @@ -10,6 +11,7 @@ use bytes::{BufMut, Bytes};
use futures::stream::StreamExt;
use hashing::Digest;
use humansize::{file_size_opts, FileSize};
use reqwest::header::{HeaderMap, HeaderName};
use reqwest::Error;
use tokio_retry::strategy::{jitter, ExponentialBackoff};
use tokio_retry::RetryIf;
Expand Down Expand Up @@ -44,11 +46,21 @@ impl NetDownload {
async fn start(
core: &Arc<Core>,
url: Url,
auth_headers: BTreeMap<String, String>,
file_name: String,
) -> Result<NetDownload, StreamingError> {
let mut headers = HeaderMap::new();
for (k, v) in auth_headers.iter() {
thejcannon marked this conversation as resolved.
Show resolved Hide resolved
headers.insert(
HeaderName::from_bytes(k.as_bytes()).unwrap(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe avoid the .unwrap and return an error instead (using ? operator to just do it inline)? I.e., .map_err(|err| ...)?.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment for the following line.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case it'd be a dev mistake, so I'd rather crash-and-burn to make it obvious.

v.parse().unwrap(),
);
}

let response = core
.http_client
.get(url.clone())
.headers(headers)
.send()
.await
.map_err(|err| StreamingError::Retryable(format!("Error downloading file: {err}")))
Expand Down Expand Up @@ -127,6 +139,7 @@ impl StreamingDownload for FileDownload {
async fn attempt_download(
core: &Arc<Core>,
url: &Url,
auth_headers: &BTreeMap<String, String>,
file_name: String,
expected_digest: Digest,
) -> Result<(Digest, Bytes), StreamingError> {
Expand All @@ -144,7 +157,7 @@ async fn attempt_download(
}
Box::new(FileDownload::start(url.path(), file_name).await?)
} else {
Box::new(NetDownload::start(core, url.clone(), file_name).await?)
Box::new(NetDownload::start(core, url.clone(), auth_headers.clone(), file_name).await?)
}
};

Expand Down Expand Up @@ -195,6 +208,7 @@ async fn attempt_download(
pub async fn download(
core: Arc<Core>,
url: Url,
auth_headers: BTreeMap<String, String>,
file_name: String,
expected_digest: hashing::Digest,
) -> Result<(), String> {
Expand All @@ -215,7 +229,15 @@ pub async fn download(
let retry_strategy = ExponentialBackoff::from_millis(10).map(jitter).take(4);
RetryIf::spawn(
retry_strategy,
|| attempt_download(&core2, &url, file_name.clone(), expected_digest),
|| {
attempt_download(
&core2,
&url,
&auth_headers,
file_name.clone(),
expected_digest,
)
},
|err: &StreamingError| matches!(err, StreamingError::Retryable(_)),
)
.await
Expand Down
4 changes: 2 additions & 2 deletions src/rust/engine/src/externs/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ impl PyTypes {
path_globs: &PyType,
create_digest: &PyType,
digest_subset: &PyType,
download_file: &PyType,
native_download_file: &PyType,
platform: &PyType,
process: &PyType,
process_result: &PyType,
Expand Down Expand Up @@ -215,7 +215,7 @@ impl PyTypes {
remove_prefix: TypeId::new(py.get_type::<externs::fs::PyRemovePrefix>()),
create_digest: TypeId::new(create_digest),
digest_subset: TypeId::new(digest_subset),
download_file: TypeId::new(download_file),
native_download_file: TypeId::new(native_download_file),
platform: TypeId::new(platform),
process: TypeId::new(process),
process_result: TypeId::new(process_result),
Expand Down
Loading