Skip to content

Commit

Permalink
feat(http-hook): add adapter parameter to HttpHook and enhance get_conn
Browse files Browse the repository at this point in the history
- Added `adapter` parameter to `HttpHook` to allow custom HTTP adapters.
- Modified `get_conn` to support mounting custom adapters or using TCPKeepAliveAdapter by default.
- Added comprehensive tests to validate the functionality of the `adapter` parameter and its integration with `get_conn`.
- Ensured all new tests pass and maintain compatibility with existing functionality.
  • Loading branch information
jiao committed Nov 23, 2024
1 parent 3c58e01 commit 1697dc2
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 1 deletion.
20 changes: 20 additions & 0 deletions providers/src/airflow/providers/http/hooks/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@

import asyncio
from typing import TYPE_CHECKING, Any, Callable
from urllib.parse import urlparse

import aiohttp
import requests
import tenacity
from aiohttp import ClientResponseError
from asgiref.sync import sync_to_async
from requests.adapters import BaseAdapter
from requests.auth import HTTPBasicAuth
from requests.models import DEFAULT_REDIRECT_LIMIT
from requests_toolbelt.adapters.socket_options import TCPKeepAliveAdapter
Expand Down Expand Up @@ -72,6 +74,7 @@ def __init__(
method: str = "POST",
http_conn_id: str = default_conn_name,
auth_type: Any = None,
adapter: BaseAdapter | None = None,
tcp_keep_alive: bool = True,
tcp_keep_alive_idle: int = 120,
tcp_keep_alive_count: int = 20,
Expand All @@ -83,6 +86,11 @@ def __init__(
self.base_url: str = ""
self._retry_obj: Callable[..., Any]
self._auth_type: Any = auth_type

if adapter is not None and not isinstance(adapter, BaseAdapter):
raise TypeError("adapter must be an instance of requests.adapters.BaseAdapter")
self.adapter = adapter

self.tcp_keep_alive = tcp_keep_alive
self.keep_alive_idle = tcp_keep_alive_idle
self.keep_alive_count = tcp_keep_alive_count
Expand Down Expand Up @@ -143,6 +151,18 @@ def get_conn(self, headers: dict[Any, Any] | None = None) -> requests.Session:
if headers:
session.headers.update(headers)

if self.adapter:
scheme = urlparse(self.base_url).scheme if self.base_url else "https"
session.mount(f"{scheme}://", self.adapter)
elif self.tcp_keep_alive:
keep_alive_adapter = TCPKeepAliveAdapter(
idle=self.keep_alive_idle,
count=self.keep_alive_count,
interval=self.keep_alive_interval,
)
session.mount("http://", keep_alive_adapter)
session.mount("https://", keep_alive_adapter)

return session

def run(
Expand Down
63 changes: 62 additions & 1 deletion providers/tests/http/hooks/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,16 @@
import os
from http import HTTPStatus
from unittest import mock
from unittest.mock import patch

import pytest
import requests
import tenacity
from aioresponses import aioresponses
from requests.adapters import Response
from requests.adapters import HTTPAdapter, Response
from requests.auth import AuthBase, HTTPBasicAuth
from requests.models import DEFAULT_REDIRECT_LIMIT
from requests_toolbelt.adapters.socket_options import TCPKeepAliveAdapter

from airflow.exceptions import AirflowException
from airflow.models import Connection
Expand Down Expand Up @@ -536,6 +538,65 @@ def test_url_from_endpoint(self, base_url: str, endpoint: str, expected_url: str
hook.base_url = base_url
assert hook.url_from_endpoint(endpoint) == expected_url

@pytest.fixture
def custom_adapter(self):
"""Fixture to provide a custom HTTPAdapter."""
return HTTPAdapter()

@pytest.fixture
def mock_session(self):
with patch("requests.Session") as mock_session:
yield mock_session

@mock.patch("requests.Session")
def test_get_conn_with_custom_adapter(self, mock_session):
"""Test that a custom adapter is correctly mounted to the session."""
custom_adapter = HTTPAdapter()
hook = HttpHook(adapter=custom_adapter)

# Call get_conn to trigger adapter mounting
hook.get_conn()

# Verify that session.mount was called with the custom adapter
expected_scheme = "https://"
mock_session.return_value.mount.assert_called_with(expected_scheme, custom_adapter)

@mock.patch("requests.Session")
def test_get_conn_without_adapter_uses_default(self, mock_session):
"""Test that default TCPKeepAliveAdapter is used when no custom adapter is provided."""
hook = HttpHook(tcp_keep_alive=True)

# Call get_conn to trigger adapter mounting
hook.get_conn()

# Verify that TCPKeepAliveAdapter is used
calls = mock_session.return_value.mount.call_args_list
assert len(calls) == 2 # Should mount for 'http://' and 'https://'
adapters_used = [call.args[1] for call in calls]
assert all(isinstance(adapter, TCPKeepAliveAdapter) for adapter in adapters_used)

@mock.patch("requests.Session")
def test_get_conn_with_adapter_and_tcp_keep_alive(self, mock_session):
"""Test that when both adapter and tcp_keep_alive are provided, custom adapter is used."""
custom_adapter = HTTPAdapter()
hook = HttpHook(adapter=custom_adapter, tcp_keep_alive=True)

# Call get_conn to trigger adapter mounting
hook.get_conn()

# Verify that the custom adapter is used instead of TCPKeepAliveAdapter
expected_scheme = "https://"
mock_session.return_value.mount.assert_called_with(expected_scheme, custom_adapter)
# Ensure TCPKeepAliveAdapter is not mounted
calls = mock_session.return_value.mount.call_args_list
adapters_used = [call.args[1] for call in calls]
assert not any(isinstance(adapter, TCPKeepAliveAdapter) for adapter in adapters_used)

def test_adapter_invalid_type(self):
"""Test that providing an invalid adapter type raises TypeError."""
with pytest.raises(TypeError, match="adapter must be an instance of requests.adapters.BaseAdapter"):
HttpHook(adapter="not_an_adapter")


class TestHttpAsyncHook:
@pytest.mark.asyncio
Expand Down

0 comments on commit 1697dc2

Please sign in to comment.