Skip to content

Commit

Permalink
feat(http): update get_conn logic and corresponding tests (apache#44302)
Browse files Browse the repository at this point in the history
Aligned the `get_conn` method with the adjustments specified in apache#44302,
including refined handling of headers. Optimized and updated test cases
to ensure compatibility and maintain robust test coverage.
  • Loading branch information
jiao committed Nov 24, 2024
1 parent 311c5ca commit 3b5ebf1
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 42 deletions.
4 changes: 2 additions & 2 deletions providers/src/airflow/providers/http/hooks/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def auth_type(self, v):

# headers may be passed through directly or in the "extra" field in the connection
# definition
def get_conn(self, headers: dict[Any, Any] = None) -> requests.Session:
def get_conn(self, headers: dict[Any, Any] | None = None) -> requests.Session:
"""
Create a Requests HTTP session.
Expand Down Expand Up @@ -162,7 +162,7 @@ def _set_extra(self, session: requests.Session, connection) -> None:
session.stream = extra.pop("stream", False)
session.verify = extra.pop("verify", extra.pop("verify_ssl", True))
session.cert = extra.pop("cert", None)
session.max_redirects = extra.pop("max_redirects", requests.adapters.DEFAULT_REDIRECT_LIMIT)
session.max_redirects = extra.pop("max_redirects", DEFAULT_REDIRECT_LIMIT)
session.trust_env = extra.pop("trust_env", True)

try:
Expand Down
85 changes: 45 additions & 40 deletions providers/tests/http/hooks/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import os
from http import HTTPStatus
from unittest import mock
from unittest.mock import patch

import pytest
import requests
Expand All @@ -49,6 +48,33 @@ def aioresponse():
yield async_response


@pytest.fixture
def mock_connection():
"""Fixture to provide a mocked connection."""
connection = mock.Mock()
connection.host = "example.com"
connection.schema = "https"
connection.port = None
connection.login = None
connection.password = None
connection.extra = None
return connection


@pytest.fixture
def mock_session():
"""Fixture to provide a mocked requests session."""
with mock.patch("requests.Session") as session:
yield session


@pytest.fixture
def hook_with_mock_connection(mock_connection):
"""Fixture to patch `get_connection` in HttpHook."""
with mock.patch.object(HttpHook, "get_connection", return_value=mock_connection):
yield HttpHook


def get_airflow_connection(conn_id: str = "http_default"):
return Connection(conn_id=conn_id, conn_type="http", host="test:8080/", extra='{"bearer": "test"}')

Expand Down Expand Up @@ -538,62 +564,41 @@ def test_url_from_endpoint(self, base_url: str, endpoint: str, expected_url: str
hook.base_url = base_url
assert hook.url_from_endpoint(endpoint) == expected_url

@pytest.fixture
def custom_adapter(self):
"""Fixture to provide a custom HTTPAdapter."""
return HTTPAdapter()

@pytest.fixture
def mock_session(self):
with patch("requests.Session") as mock_session:
yield mock_session

@mock.patch("airflow.hooks.base.BaseHook.get_connection")
@mock.patch("requests.Session")
def test_get_conn_with_custom_adapter(self, mock_session):
def test_get_conn_with_custom_adapter(self, hook_with_mock_connection, mock_session):
"""Test that a custom adapter is correctly mounted to the session."""
custom_adapter = HTTPAdapter()
hook = HttpHook(adapter=custom_adapter)
hook = hook_with_mock_connection(adapter=custom_adapter)

# Call get_conn to trigger adapter mounting
hook.get_conn()

# Verify that session.mount was called with the custom adapter
expected_scheme = "https://"
mock_session.return_value.mount.assert_called_with(expected_scheme, custom_adapter)
expected_scheme = "https" # Set schema in mock_connection
mock_session.return_value.mount.assert_called_with(f"{expected_scheme}://", custom_adapter)

@mock.patch("airflow.hooks.base.BaseHook.get_connection")
@mock.patch("requests.Session")
def test_get_conn_without_adapter_uses_default(self, mock_session):
def test_get_conn_without_adapter_uses_default(self, hook_with_mock_connection, mock_session):
"""Test that default TCPKeepAliveAdapter is used when no custom adapter is provided."""
hook = HttpHook(tcp_keep_alive=True)
hook = hook_with_mock_connection(tcp_keep_alive=True)

# Call get_conn to trigger adapter mounting
hook.get_conn()

# Verify that TCPKeepAliveAdapter is used
calls = mock_session.return_value.mount.call_args_list
assert len(calls) == 2 # Should mount for 'http://' and 'https://'
adapters_used = [call.args[1] for call in calls]
assert all(isinstance(adapter, TCPKeepAliveAdapter) for adapter in adapters_used)
assert len(calls) == 2 # Mount for both 'http://' and 'https://'
adapters = [call.args[1] for call in calls]
assert all(isinstance(adapter, TCPKeepAliveAdapter) for adapter in adapters)

@mock.patch("airflow.hooks.base.BaseHook.get_connection")
@mock.patch("requests.Session")
def test_get_conn_with_adapter_and_tcp_keep_alive(self, mock_session):
"""Test that when both adapter and tcp_keep_alive are provided, custom adapter is used."""
def test_get_conn_with_adapter_and_tcp_keep_alive(self, hook_with_mock_connection, mock_session):
"""Test that custom adapter is used when both adapter and tcp_keep_alive are provided."""
custom_adapter = HTTPAdapter()
hook = HttpHook(adapter=custom_adapter, tcp_keep_alive=True)
hook = hook_with_mock_connection(adapter=custom_adapter, tcp_keep_alive=True)

# Call get_conn to trigger adapter mounting
hook.get_conn()

# Verify that the custom adapter is used instead of TCPKeepAliveAdapter
expected_scheme = "https://"
mock_session.return_value.mount.assert_called_with(expected_scheme, custom_adapter)
# Ensure TCPKeepAliveAdapter is not mounted
expected_scheme = "https"
mock_session.return_value.mount.assert_called_with(f"{expected_scheme}://", custom_adapter)

# Ensure TCPKeepAliveAdapter is not used
calls = mock_session.return_value.mount.call_args_list
adapters_used = [call.args[1] for call in calls]
assert not any(isinstance(adapter, TCPKeepAliveAdapter) for adapter in adapters_used)
adapters = [call.args[1] for call in calls]
assert not any(isinstance(adapter, TCPKeepAliveAdapter) for adapter in adapters)

def test_adapter_invalid_type(self):
"""Test that providing an invalid adapter type raises TypeError."""
Expand Down

0 comments on commit 3b5ebf1

Please sign in to comment.