Skip to content

Commit

Permalink
[Providers/HTTP] Add adapter parameter to HttpHook to allow custom re…
Browse files Browse the repository at this point in the history
…quests adapters (#44302)

* feat(http-hook): add adapter parameter to HttpHook and enhance get_conn

- Added `adapter` parameter to `HttpHook` to allow custom HTTP adapters.
- Modified `get_conn` to support mounting custom adapters or using TCPKeepAliveAdapter by default.
- Added comprehensive tests to validate the functionality of the `adapter` parameter and its integration with `get_conn`.
- Ensured all new tests pass and maintain compatibility with existing functionality.

* fix(http_hook): Update docstring and remove redundant TCPKeepAliveAdapter

- Added missing `adapter` parameter description to the HttpHook class docstring.
- Removed redundant instantiation of `TCPKeepAliveAdapter` in the `run` method since it's already instantiated in `get_conn`.

* fix(http_hook): improve get_conn session setup and TCP adapter logic

- Ensured proper mounting of TCP Keep-Alive adapter when enabled.
- Improved handling of connection extras for cleaner session configuration.

* feat(http): update get_conn logic and corresponding tests (#44302)

Aligned the `get_conn` method with the adjustments specified in #44302,
including refined handling of headers. Optimized and updated test cases
to ensure compatibility and maintain robust test coverage.

* refactor(http_hook): simplify HttpHook by reverting BaseAdapter to HTTPAdapter

- Changed the `adapter` parameter to accept only `HTTPAdapter` instead of `BaseAdapter`.
- Strengthened `_set_base_url` validation to ensure base_url is constructed with stricter conditions.
- Adjusted `_mount_adapters` to improve maintainability.

* refactor(http_hook): simplify HttpHook by reverting BaseAdapter to HTTPAdapter

- Changed the `adapter` parameter to accept only `HTTPAdapter` instead of `BaseAdapter`.
- Strengthened `_set_base_url` validation to ensure base_url is constructed with stricter conditions.
- Adjusted `_mount_adapters` to improve maintainability.

* Merge: new main

* refactor: improve function naming and add type annotations

- Changed the function prefix from `_set` to `_configure_session_from` to enhance readability and better reflect its purpose.
- Added static type annotations for input parameters and return values.
- Included comments to document the design rationale following coding standards.
- Improved error message: replaced generic text with detailed and actionable messages.

* fix: simplify the change of session

- Added a variable `session` after the change of session member

* fix: Adjust response format.

* fix: simplify the logic

* fix(hook): ensure default HTTPAdapter in HttpHook init

The `adapter` parameter in `HttpHook` was previously required to be explicitly
set to an instance of `HTTPAdapter`. This commit modifies the `__init__`
method to assign a default `HTTPAdapter` when no adapter is provided.

Changes:
- Removed type checks for `adapter`, as default initialization guarantees correctness.
- Improved code readability and reduced potential runtime errors.

No functional changes beyond defaulting `adapter` to `HTTPAdapter`.

* feat(http_hook): add support for custom adapter in initialization

Refactored `HttpHook` to support a custom `HTTPAdapter` through the `adapter` parameter. If no adapter is provided, it defaults to `TCPKeepAliveAdapter` when `tcp_keep_alive=True`.

Test: Added `test_custom_adapter` to verify correct adapter mounting.

* fix: CI image checks / Static checks

- Adjust the length of each line of code.

* fix: Adjust indent style

- modify `assert instance` by PEP8

* fix: ruff error about `from requests.adapters import HTTPAdapter`

---------

Co-authored-by: jiao <[email protected]>
  • Loading branch information
jieyao-MilestoneHub and jiao authored Dec 3, 2024
1 parent a3a5969 commit 71fec4e
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 47 deletions.
126 changes: 80 additions & 46 deletions providers/src/airflow/providers/http/hooks/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import asyncio
from typing import TYPE_CHECKING, Any, Callable
from urllib.parse import urlparse

import aiohttp
import requests
Expand All @@ -34,6 +35,7 @@

if TYPE_CHECKING:
from aiohttp.client_reqrep import ClientResponse
from requests.adapters import HTTPAdapter

from airflow.models import Connection

Expand All @@ -54,6 +56,7 @@ class HttpHook(BaseHook):
API url i.e https://www.google.com/ and optional authentication credentials. Default
headers can also be specified in the Extra field in json format.
:param auth_type: The auth type for the service
:param adapter: An optional instance of `requests.adapters.HTTPAdapter` to mount for the session.
:param tcp_keep_alive: Enable TCP Keep Alive for the connection.
:param tcp_keep_alive_idle: The TCP Keep Alive Idle parameter (corresponds to ``socket.TCP_KEEPIDLE``).
:param tcp_keep_alive_count: The TCP Keep Alive count parameter (corresponds to ``socket.TCP_KEEPCNT``)
Expand All @@ -76,17 +79,25 @@ def __init__(
tcp_keep_alive_idle: int = 120,
tcp_keep_alive_count: int = 20,
tcp_keep_alive_interval: int = 30,
adapter: HTTPAdapter | None = None,
) -> None:
super().__init__()
self.http_conn_id = http_conn_id
self.method = method.upper()
self.base_url: str = ""
self._retry_obj: Callable[..., Any]
self._auth_type: Any = auth_type
self.tcp_keep_alive = tcp_keep_alive
self.keep_alive_idle = tcp_keep_alive_idle
self.keep_alive_count = tcp_keep_alive_count
self.keep_alive_interval = tcp_keep_alive_interval

# If no adapter is provided, use TCPKeepAliveAdapter (default behavior)
self.adapter = adapter
if tcp_keep_alive and adapter is None:
self.keep_alive_adapter = TCPKeepAliveAdapter(
idle=tcp_keep_alive_idle,
count=tcp_keep_alive_count,
interval=tcp_keep_alive_interval,
)
else:
self.keep_alive_adapter = None

@property
def auth_type(self):
Expand All @@ -102,47 +113,76 @@ def get_conn(self, headers: dict[Any, Any] | None = None) -> requests.Session:
"""
Create a Requests HTTP session.
:param headers: additional headers to be passed through as a dictionary
:param headers: Additional headers to be passed through as a dictionary.
:return: A configured requests.Session object.
"""
session = requests.Session()

if self.http_conn_id:
conn = self.get_connection(self.http_conn_id)

if conn.host and "://" in conn.host:
self.base_url = conn.host
else:
# schema defaults to HTTP
schema = conn.schema if conn.schema else "http"
host = conn.host if conn.host else ""
self.base_url = f"{schema}://{host}"

if conn.port:
self.base_url += f":{conn.port}"
if conn.login:
session.auth = self.auth_type(conn.login, conn.password)
elif self._auth_type:
session.auth = self.auth_type()
if conn.extra:
extra = conn.extra_dejson
extra.pop(
"timeout", None
) # ignore this as timeout is only accepted in request method of Session
extra.pop("allow_redirects", None) # ignore this as only max_redirects is accepted in Session
session.proxies = extra.pop("proxies", extra.pop("proxy", {}))
session.stream = extra.pop("stream", False)
session.verify = extra.pop("verify", extra.pop("verify_ssl", True))
session.cert = extra.pop("cert", None)
session.max_redirects = extra.pop("max_redirects", DEFAULT_REDIRECT_LIMIT)
session.trust_env = extra.pop("trust_env", True)

try:
session.headers.update(extra)
except TypeError:
self.log.warning("Connection to %s has invalid extra field.", conn.host)
connection = self.get_connection(self.http_conn_id)
self._set_base_url(connection)
session = self._configure_session_from_auth(session, connection)
if connection.extra:
session = self._configure_session_from_extra(session, connection)
session = self._configure_session_from_mount_adapters(session)
if headers:
session.headers.update(headers)
return session

def _set_base_url(self, connection: Connection) -> None:
host = connection.host or ""
schema = connection.schema or "http"
# RFC 3986 (https://www.rfc-editor.org/rfc/rfc3986.html#page-16)
if "://" in host:
self.base_url = host
else:
self.base_url = f"{schema}://{host}" if host else f"{schema}://"
if connection.port:
self.base_url = f"{self.base_url}:{connection.port}"
parsed = urlparse(self.base_url)
if not parsed.scheme:
raise ValueError(f"Invalid base URL: Missing scheme in {self.base_url}")

def _configure_session_from_auth(
self, session: requests.Session, connection: Connection
) -> requests.Session:
session.auth = self._extract_auth(connection)
return session

def _extract_auth(self, connection: Connection) -> Any | None:
if connection.login:
return self.auth_type(connection.login, connection.password)
elif self._auth_type:
return self.auth_type()
return None

def _configure_session_from_extra(
self, session: requests.Session, connection: Connection
) -> requests.Session:
extra = connection.extra_dejson
extra.pop("timeout", None)
extra.pop("allow_redirects", None)
session.proxies = extra.pop("proxies", extra.pop("proxy", {}))
session.stream = extra.pop("stream", False)
session.verify = extra.pop("verify", extra.pop("verify_ssl", True))
session.cert = extra.pop("cert", None)
session.max_redirects = extra.pop("max_redirects", DEFAULT_REDIRECT_LIMIT)
session.trust_env = extra.pop("trust_env", True)
try:
session.headers.update(extra)
except TypeError:
self.log.warning("Connection to %s has invalid extra field.", connection.host)
return session

def _configure_session_from_mount_adapters(self, session: requests.Session) -> requests.Session:
scheme = urlparse(self.base_url).scheme
if not scheme:
raise ValueError(
f"Cannot mount adapters: {self.base_url} does not include a valid scheme (http or https)."
)
if self.adapter:
session.mount(f"{scheme}://", self.adapter)
elif self.keep_alive_adapter:
session.mount("http://", self.keep_alive_adapter)
session.mount("https://", self.keep_alive_adapter)
return session

def run(
Expand Down Expand Up @@ -171,11 +211,6 @@ def run(

url = self.url_from_endpoint(endpoint)

if self.tcp_keep_alive:
keep_alive_adapter = TCPKeepAliveAdapter(
idle=self.keep_alive_idle, count=self.keep_alive_count, interval=self.keep_alive_interval
)
session.mount(url, keep_alive_adapter)
if self.method == "GET":
# GET uses params
req = requests.Request(self.method, url, params=data, headers=headers, **request_kwargs)
Expand Down Expand Up @@ -467,5 +502,4 @@ def _retryable_error_async(self, exception: ClientResponseError) -> bool:
if exception.status == 413:
# don't retry for payload Too Large
return False

return exception.status >= 500
16 changes: 15 additions & 1 deletion providers/tests/http/hooks/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import requests
import tenacity
from aioresponses import aioresponses
from requests.adapters import Response
from requests.adapters import HTTPAdapter, Response
from requests.auth import AuthBase, HTTPBasicAuth
from requests.models import DEFAULT_REDIRECT_LIMIT

Expand Down Expand Up @@ -536,6 +536,20 @@ def test_url_from_endpoint(self, base_url: str, endpoint: str, expected_url: str
hook.base_url = base_url
assert hook.url_from_endpoint(endpoint) == expected_url

def test_custom_adapter(self):
with mock.patch(
"airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection_with_port
):
custom_adapter = HTTPAdapter()
hook = HttpHook(method="GET", adapter=custom_adapter)
session = hook.get_conn()
assert isinstance(
session.adapters["http://"], type(custom_adapter)
), "Custom HTTP adapter not correctly mounted"
assert isinstance(
session.adapters["https://"], type(custom_adapter)
), "Custom HTTPS adapter not correctly mounted"


class TestHttpAsyncHook:
@pytest.mark.asyncio
Expand Down

0 comments on commit 71fec4e

Please sign in to comment.