Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add host_proxy_cmd parameter to SSHHook and SFTPHook #44565

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions providers/src/airflow/providers/sftp/hooks/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def __init__(
self,
ssh_conn_id: str | None = "sftp_default",
ssh_hook: SSHHook | None = None,
host_proxy_cmd: str | None = None,
*args,
**kwargs,
) -> None:
Expand Down Expand Up @@ -115,6 +116,7 @@ def __init__(
ssh_conn_id = ftp_conn_id

kwargs["ssh_conn_id"] = ssh_conn_id
kwargs["host_proxy_cmd"] = host_proxy_cmd
self.ssh_conn_id = ssh_conn_id

super().__init__(*args, **kwargs)
Expand Down
5 changes: 3 additions & 2 deletions providers/src/airflow/providers/ssh/hooks/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def __init__(
disabled_algorithms: dict | None = None,
ciphers: list[str] | None = None,
auth_timeout: int | None = None,
host_proxy_cmd: str | None = None,
) -> None:
super().__init__()
self.ssh_conn_id = ssh_conn_id
Expand All @@ -134,7 +135,7 @@ def __init__(
self.banner_timeout = banner_timeout
self.disabled_algorithms = disabled_algorithms
self.ciphers = ciphers
self.host_proxy_cmd = None
self.host_proxy_cmd = host_proxy_cmd
self.auth_timeout = auth_timeout

# Default values, overridable from Connection
Expand Down Expand Up @@ -246,7 +247,7 @@ def __init__(
with open(user_ssh_config_filename) as config_fd:
ssh_conf.parse(config_fd)
host_info = ssh_conf.lookup(self.remote_host)
if host_info and host_info.get("proxycommand"):
if host_info and host_info.get("proxycommand") and not self.host_proxy_cmd:
self.host_proxy_cmd = host_info["proxycommand"]

if not (self.password or self.key_file):
Expand Down
28 changes: 28 additions & 0 deletions providers/tests/sftp/hooks/test_sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,3 +788,31 @@ async def test_get_mod_time_exception(self, mock_hook_get_conn):
with pytest.raises(AirflowException) as exc:
await hook.get_mod_time("/path/does_not/exist/")
assert str(exc.value) == "No files matching"

@patch("paramiko.SSHClient")
@mock.patch("paramiko.ProxyCommand")
def test_sftp_hook_with_proxy_command(self, mock_proxy_command, mock_ssh_client):
mock_transport = mock.MagicMock()
mock_ssh_client.return_value.get_transport.return_value = mock_transport
mock_proxy_command.return_value = mock.MagicMock()

host_proxy_cmd = "ncat --proxy-auth proxy_user:**** --proxy proxy_host:port %h %p"
hook = SFTPHook(
remote_host="example.com",
username="user",
host_proxy_cmd=host_proxy_cmd,
)
hook.get_conn()

mock_proxy_command.assert_called_once_with(host_proxy_cmd)
mock_ssh_client.return_value.connect.assert_called_once_with(
hostname="example.com",
username="user",
timeout=None,
compress=True,
port=22,
sock=mock_proxy_command.return_value,
look_for_keys=True,
banner_timeout=30.0,
auth_timeout=None,
)
30 changes: 30 additions & 0 deletions providers/tests/ssh/hooks/test_ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,3 +955,33 @@ def test_ssh_connection_client_is_recreated_if_transport_closed(self):
client2 = hook.get_conn()
assert client1 is not client2
assert client2.get_transport().is_active()

@mock.patch("paramiko.SSHClient")
@mock.patch("paramiko.ProxyCommand")
def test_ssh_hook_with_proxy_command(self, mock_proxy_command, mock_ssh_client):
# Mock transport and proxy command behavior
mock_transport = mock.MagicMock()
mock_ssh_client.return_value.get_transport.return_value = mock_transport
mock_proxy_command.return_value = mock.MagicMock()

# Create the SSHHook with the proxy command
host_proxy_cmd = "ncat --proxy-auth proxy_user:**** --proxy proxy_host:port %h %p"
hook = SSHHook(
remote_host="example.com",
username="user",
host_proxy_cmd=host_proxy_cmd,
)
hook.get_conn()

mock_proxy_command.assert_called_once_with(host_proxy_cmd)
mock_ssh_client.return_value.connect.assert_called_once_with(
hostname="example.com",
username="user",
timeout=None,
compress=True,
port=22,
sock=mock_proxy_command.return_value,
look_for_keys=True,
banner_timeout=30.0,
auth_timeout=None,
)