From 3b5ebf13f95cae08438a92c77ff1a24a54f6daa4 Mon Sep 17 00:00:00 2001 From: jiao Date: Sun, 24 Nov 2024 16:18:41 +0800 Subject: [PATCH] feat(http): update get_conn logic and corresponding tests (#44302) Aligned the `get_conn` method with the adjustments specified in #44302, including refined handling of headers. Optimized and updated test cases to ensure compatibility and maintain robust test coverage. --- .../src/airflow/providers/http/hooks/http.py | 4 +- providers/tests/http/hooks/test_http.py | 85 ++++++++++--------- 2 files changed, 47 insertions(+), 42 deletions(-) diff --git a/providers/src/airflow/providers/http/hooks/http.py b/providers/src/airflow/providers/http/hooks/http.py index a48267f0aa928..ae4db8198dd92 100644 --- a/providers/src/airflow/providers/http/hooks/http.py +++ b/providers/src/airflow/providers/http/hooks/http.py @@ -116,7 +116,7 @@ def auth_type(self, v): # headers may be passed through directly or in the "extra" field in the connection # definition - def get_conn(self, headers: dict[Any, Any] = None) -> requests.Session: + def get_conn(self, headers: dict[Any, Any] | None = None) -> requests.Session: """ Create a Requests HTTP session. @@ -162,7 +162,7 @@ def _set_extra(self, session: requests.Session, connection) -> None: session.stream = extra.pop("stream", False) session.verify = extra.pop("verify", extra.pop("verify_ssl", True)) session.cert = extra.pop("cert", None) - session.max_redirects = extra.pop("max_redirects", requests.adapters.DEFAULT_REDIRECT_LIMIT) + session.max_redirects = extra.pop("max_redirects", DEFAULT_REDIRECT_LIMIT) session.trust_env = extra.pop("trust_env", True) try: diff --git a/providers/tests/http/hooks/test_http.py b/providers/tests/http/hooks/test_http.py index d16a4aa92a21b..2edccc1541c71 100644 --- a/providers/tests/http/hooks/test_http.py +++ b/providers/tests/http/hooks/test_http.py @@ -24,7 +24,6 @@ import os from http import HTTPStatus from unittest import mock -from unittest.mock import patch import pytest import requests @@ -49,6 +48,33 @@ def aioresponse(): yield async_response +@pytest.fixture +def mock_connection(): + """Fixture to provide a mocked connection.""" + connection = mock.Mock() + connection.host = "example.com" + connection.schema = "https" + connection.port = None + connection.login = None + connection.password = None + connection.extra = None + return connection + + +@pytest.fixture +def mock_session(): + """Fixture to provide a mocked requests session.""" + with mock.patch("requests.Session") as session: + yield session + + +@pytest.fixture +def hook_with_mock_connection(mock_connection): + """Fixture to patch `get_connection` in HttpHook.""" + with mock.patch.object(HttpHook, "get_connection", return_value=mock_connection): + yield HttpHook + + def get_airflow_connection(conn_id: str = "http_default"): return Connection(conn_id=conn_id, conn_type="http", host="test:8080/", extra='{"bearer": "test"}') @@ -538,62 +564,41 @@ def test_url_from_endpoint(self, base_url: str, endpoint: str, expected_url: str hook.base_url = base_url assert hook.url_from_endpoint(endpoint) == expected_url - @pytest.fixture - def custom_adapter(self): - """Fixture to provide a custom HTTPAdapter.""" - return HTTPAdapter() - - @pytest.fixture - def mock_session(self): - with patch("requests.Session") as mock_session: - yield mock_session - - @mock.patch("airflow.hooks.base.BaseHook.get_connection") - @mock.patch("requests.Session") - def test_get_conn_with_custom_adapter(self, mock_session): + def test_get_conn_with_custom_adapter(self, hook_with_mock_connection, mock_session): """Test that a custom adapter is correctly mounted to the session.""" custom_adapter = HTTPAdapter() - hook = HttpHook(adapter=custom_adapter) + hook = hook_with_mock_connection(adapter=custom_adapter) - # Call get_conn to trigger adapter mounting hook.get_conn() - # Verify that session.mount was called with the custom adapter - expected_scheme = "https://" - mock_session.return_value.mount.assert_called_with(expected_scheme, custom_adapter) + expected_scheme = "https" # Set schema in mock_connection + mock_session.return_value.mount.assert_called_with(f"{expected_scheme}://", custom_adapter) - @mock.patch("airflow.hooks.base.BaseHook.get_connection") - @mock.patch("requests.Session") - def test_get_conn_without_adapter_uses_default(self, mock_session): + def test_get_conn_without_adapter_uses_default(self, hook_with_mock_connection, mock_session): """Test that default TCPKeepAliveAdapter is used when no custom adapter is provided.""" - hook = HttpHook(tcp_keep_alive=True) + hook = hook_with_mock_connection(tcp_keep_alive=True) - # Call get_conn to trigger adapter mounting hook.get_conn() - # Verify that TCPKeepAliveAdapter is used calls = mock_session.return_value.mount.call_args_list - assert len(calls) == 2 # Should mount for 'http://' and 'https://' - adapters_used = [call.args[1] for call in calls] - assert all(isinstance(adapter, TCPKeepAliveAdapter) for adapter in adapters_used) + assert len(calls) == 2 # Mount for both 'http://' and 'https://' + adapters = [call.args[1] for call in calls] + assert all(isinstance(adapter, TCPKeepAliveAdapter) for adapter in adapters) - @mock.patch("airflow.hooks.base.BaseHook.get_connection") - @mock.patch("requests.Session") - def test_get_conn_with_adapter_and_tcp_keep_alive(self, mock_session): - """Test that when both adapter and tcp_keep_alive are provided, custom adapter is used.""" + def test_get_conn_with_adapter_and_tcp_keep_alive(self, hook_with_mock_connection, mock_session): + """Test that custom adapter is used when both adapter and tcp_keep_alive are provided.""" custom_adapter = HTTPAdapter() - hook = HttpHook(adapter=custom_adapter, tcp_keep_alive=True) + hook = hook_with_mock_connection(adapter=custom_adapter, tcp_keep_alive=True) - # Call get_conn to trigger adapter mounting hook.get_conn() - # Verify that the custom adapter is used instead of TCPKeepAliveAdapter - expected_scheme = "https://" - mock_session.return_value.mount.assert_called_with(expected_scheme, custom_adapter) - # Ensure TCPKeepAliveAdapter is not mounted + expected_scheme = "https" + mock_session.return_value.mount.assert_called_with(f"{expected_scheme}://", custom_adapter) + + # Ensure TCPKeepAliveAdapter is not used calls = mock_session.return_value.mount.call_args_list - adapters_used = [call.args[1] for call in calls] - assert not any(isinstance(adapter, TCPKeepAliveAdapter) for adapter in adapters_used) + adapters = [call.args[1] for call in calls] + assert not any(isinstance(adapter, TCPKeepAliveAdapter) for adapter in adapters) def test_adapter_invalid_type(self): """Test that providing an invalid adapter type raises TypeError."""