From 992cb61d89fc2265793a95de2e360777ba7c8d32 Mon Sep 17 00:00:00 2001 From: Ajit J Gupta Date: Mon, 2 Dec 2024 16:44:45 +0530 Subject: [PATCH 1/3] Add host_proxy_cmd parameter to SSHHook and SFTPHook --- .../src/airflow/providers/sftp/hooks/sftp.py | 2 ++ providers/src/airflow/providers/ssh/hooks/ssh.py | 5 +++-- providers/tests/sftp/hooks/test_sftp.py | 16 ++++++++++++++++ providers/tests/ssh/hooks/test_ssh.py | 16 ++++++++++++++++ 4 files changed, 37 insertions(+), 2 deletions(-) diff --git a/providers/src/airflow/providers/sftp/hooks/sftp.py b/providers/src/airflow/providers/sftp/hooks/sftp.py index fec11666dec38..1a826cd645c7a 100644 --- a/providers/src/airflow/providers/sftp/hooks/sftp.py +++ b/providers/src/airflow/providers/sftp/hooks/sftp.py @@ -83,6 +83,7 @@ def __init__( self, ssh_conn_id: str | None = "sftp_default", ssh_hook: SSHHook | None = None, + host_proxy_cmd: str | None = None, *args, **kwargs, ) -> None: @@ -115,6 +116,7 @@ def __init__( ssh_conn_id = ftp_conn_id kwargs["ssh_conn_id"] = ssh_conn_id + kwargs["host_proxy_cmd"] = host_proxy_cmd self.ssh_conn_id = ssh_conn_id super().__init__(*args, **kwargs) diff --git a/providers/src/airflow/providers/ssh/hooks/ssh.py b/providers/src/airflow/providers/ssh/hooks/ssh.py index 6dd00f2c8bfd2..28501cab14b06 100644 --- a/providers/src/airflow/providers/ssh/hooks/ssh.py +++ b/providers/src/airflow/providers/ssh/hooks/ssh.py @@ -124,6 +124,7 @@ def __init__( disabled_algorithms: dict | None = None, ciphers: list[str] | None = None, auth_timeout: int | None = None, + host_proxy_cmd: str | None = None, ) -> None: super().__init__() self.ssh_conn_id = ssh_conn_id @@ -140,7 +141,7 @@ def __init__( self.banner_timeout = banner_timeout self.disabled_algorithms = disabled_algorithms self.ciphers = ciphers - self.host_proxy_cmd = None + self.host_proxy_cmd = host_proxy_cmd self.auth_timeout = auth_timeout # Default values, overridable from Connection @@ -274,7 +275,7 @@ def __init__( with open(user_ssh_config_filename) as config_fd: ssh_conf.parse(config_fd) host_info = ssh_conf.lookup(self.remote_host) - if host_info and host_info.get("proxycommand"): + if host_info and host_info.get("proxycommand") and not self.host_proxy_cmd: self.host_proxy_cmd = host_info["proxycommand"] if not (self.password or self.key_file): diff --git a/providers/tests/sftp/hooks/test_sftp.py b/providers/tests/sftp/hooks/test_sftp.py index 7a7a2991a7031..b4ff76aaff50b 100644 --- a/providers/tests/sftp/hooks/test_sftp.py +++ b/providers/tests/sftp/hooks/test_sftp.py @@ -788,3 +788,19 @@ async def test_get_mod_time_exception(self, mock_hook_get_conn): with pytest.raises(AirflowException) as exc: await hook.get_mod_time("/path/does_not/exist/") assert str(exc.value) == "No files matching" + + @patch("paramiko.SSHClient") + def test_sftp_hook_with_proxy_command(self, mock_ssh_client): + mock_transport = mock.MagicMock() + mock_ssh_client.return_value.get_transport.return_value = mock_transport + + hook = SFTPHook( + remote_host="example.com", + username="user", + host_proxy_cmd="ncat --proxy-auth proxy_user:**** --proxy proxy_host:port %h %p", + ) + hook.get_conn() + + mock_transport.set_proxy.assert_called_once() + proxy_command = mock_transport.set_proxy.call_args[0][0] + assert proxy_command.cmd == "ncat --proxy-auth proxy_user:**** --proxy proxy_host:port %h %p" diff --git a/providers/tests/ssh/hooks/test_ssh.py b/providers/tests/ssh/hooks/test_ssh.py index 65e1858a6148d..5e992b8d71a76 100644 --- a/providers/tests/ssh/hooks/test_ssh.py +++ b/providers/tests/ssh/hooks/test_ssh.py @@ -1128,3 +1128,19 @@ def test_ssh_connection_client_is_recreated_if_transport_closed(self): client2 = hook.get_conn() assert client1 is not client2 assert client2.get_transport().is_active() + + @mock.patch("paramiko.SSHClient") + def test_ssh_hook_with_proxy_command(self, mock_ssh_client): + mock_transport = mock.MagicMock() + mock_ssh_client.return_value.get_transport.return_value = mock_transport + + hook = SSHHook( + remote_host="example.com", + username="user", + host_proxy_cmd="ncat --proxy-auth proxy_user:**** --proxy proxy_host:port %h %p", + ) + hook.get_conn() + + mock_transport.set_proxy.assert_called_once() + proxy_command = mock_transport.set_proxy.call_args[0][0] + assert proxy_command.cmd == "ncat --proxy-auth proxy_user:**** --proxy proxy_host:port %h %p" From e236066816476b9741ddc2d56c16db5471b8da4b Mon Sep 17 00:00:00 2001 From: Ajit J Gupta Date: Tue, 3 Dec 2024 08:43:04 +0530 Subject: [PATCH 2/3] Fix unit test case by mocking the paramiko.ProxyCommand --- providers/tests/sftp/hooks/test_sftp.py | 22 +++++++++++++++++----- providers/tests/ssh/hooks/test_ssh.py | 24 +++++++++++++++++++----- 2 files changed, 36 insertions(+), 10 deletions(-) diff --git a/providers/tests/sftp/hooks/test_sftp.py b/providers/tests/sftp/hooks/test_sftp.py index b4ff76aaff50b..a16a5aec99063 100644 --- a/providers/tests/sftp/hooks/test_sftp.py +++ b/providers/tests/sftp/hooks/test_sftp.py @@ -790,17 +790,29 @@ async def test_get_mod_time_exception(self, mock_hook_get_conn): assert str(exc.value) == "No files matching" @patch("paramiko.SSHClient") - def test_sftp_hook_with_proxy_command(self, mock_ssh_client): + @mock.patch("paramiko.ProxyCommand") + def test_sftp_hook_with_proxy_command(self, mock_proxy_command, mock_ssh_client): mock_transport = mock.MagicMock() mock_ssh_client.return_value.get_transport.return_value = mock_transport + mock_proxy_command.return_value = mock.MagicMock() + host_proxy_cmd = "ncat --proxy-auth proxy_user:**** --proxy proxy_host:port %h %p" hook = SFTPHook( remote_host="example.com", username="user", - host_proxy_cmd="ncat --proxy-auth proxy_user:**** --proxy proxy_host:port %h %p", + host_proxy_cmd=host_proxy_cmd, ) hook.get_conn() - mock_transport.set_proxy.assert_called_once() - proxy_command = mock_transport.set_proxy.call_args[0][0] - assert proxy_command.cmd == "ncat --proxy-auth proxy_user:**** --proxy proxy_host:port %h %p" + mock_proxy_command.assert_called_once_with(host_proxy_cmd) + mock_ssh_client.return_value.connect.assert_called_once_with( + hostname="example.com", + username="user", + timeout=10, + compress=True, + port=22, + sock=mock_proxy_command.return_value, + look_for_keys=True, + banner_timeout=30.0, + auth_timeout=None, + ) diff --git a/providers/tests/ssh/hooks/test_ssh.py b/providers/tests/ssh/hooks/test_ssh.py index 5e992b8d71a76..518413fd03763 100644 --- a/providers/tests/ssh/hooks/test_ssh.py +++ b/providers/tests/ssh/hooks/test_ssh.py @@ -1130,17 +1130,31 @@ def test_ssh_connection_client_is_recreated_if_transport_closed(self): assert client2.get_transport().is_active() @mock.patch("paramiko.SSHClient") - def test_ssh_hook_with_proxy_command(self, mock_ssh_client): + @mock.patch("paramiko.ProxyCommand") + def test_ssh_hook_with_proxy_command(self, mock_proxy_command, mock_ssh_client): + # Mock transport and proxy command behavior mock_transport = mock.MagicMock() mock_ssh_client.return_value.get_transport.return_value = mock_transport + mock_proxy_command.return_value = mock.MagicMock() + # Create the SSHHook with the proxy command + host_proxy_cmd = "ncat --proxy-auth proxy_user:**** --proxy proxy_host:port %h %p" hook = SSHHook( remote_host="example.com", username="user", - host_proxy_cmd="ncat --proxy-auth proxy_user:**** --proxy proxy_host:port %h %p", + host_proxy_cmd=host_proxy_cmd, ) hook.get_conn() - mock_transport.set_proxy.assert_called_once() - proxy_command = mock_transport.set_proxy.call_args[0][0] - assert proxy_command.cmd == "ncat --proxy-auth proxy_user:**** --proxy proxy_host:port %h %p" + mock_proxy_command.assert_called_once_with(host_proxy_cmd) + mock_ssh_client.return_value.connect.assert_called_once_with( + hostname="example.com", + username="user", + timeout=10, + compress=True, + port=22, + sock=mock_proxy_command.return_value, + look_for_keys=True, + banner_timeout=30.0, + auth_timeout=None, + ) From 56b7cdbb288f25c22ba50d18a11ef2dcf3be99b0 Mon Sep 17 00:00:00 2001 From: Ajit J Gupta Date: Tue, 3 Dec 2024 14:11:38 +0530 Subject: [PATCH 3/3] Fixed test cases by adding timeout=None --- providers/tests/sftp/hooks/test_sftp.py | 2 +- providers/tests/ssh/hooks/test_ssh.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/providers/tests/sftp/hooks/test_sftp.py b/providers/tests/sftp/hooks/test_sftp.py index a16a5aec99063..5f2c34a8cc0e6 100644 --- a/providers/tests/sftp/hooks/test_sftp.py +++ b/providers/tests/sftp/hooks/test_sftp.py @@ -808,7 +808,7 @@ def test_sftp_hook_with_proxy_command(self, mock_proxy_command, mock_ssh_client) mock_ssh_client.return_value.connect.assert_called_once_with( hostname="example.com", username="user", - timeout=10, + timeout=None, compress=True, port=22, sock=mock_proxy_command.return_value, diff --git a/providers/tests/ssh/hooks/test_ssh.py b/providers/tests/ssh/hooks/test_ssh.py index c203ca44030e4..e09f2eeee0af7 100644 --- a/providers/tests/ssh/hooks/test_ssh.py +++ b/providers/tests/ssh/hooks/test_ssh.py @@ -977,7 +977,7 @@ def test_ssh_hook_with_proxy_command(self, mock_proxy_command, mock_ssh_client): mock_ssh_client.return_value.connect.assert_called_once_with( hostname="example.com", username="user", - timeout=10, + timeout=None, compress=True, port=22, sock=mock_proxy_command.return_value,