Skip to content

Commit

Permalink
[AIRFLOW-2670] Update SSH Operator's Hook to respect timeout (apache#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Noremac201 authored and lxneng committed Aug 10, 2018
1 parent 2bc7d95 commit 6fbdbce
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 5 deletions.
7 changes: 4 additions & 3 deletions airflow/contrib/operators/ssh_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,17 @@ def __init__(self,
def execute(self, context):
try:
if self.ssh_conn_id and not self.ssh_hook:
self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id)
self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id,
timeout=self.timeout)

if not self.ssh_hook:
raise AirflowException("can not operate without ssh_hook or ssh_conn_id")
raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.")

if self.remote_host is not None:
self.ssh_hook.remote_host = self.remote_host

if not self.command:
raise AirflowException("no command specified so nothing to execute here.")
raise AirflowException("SSH command not specified. Aborting.")

with self.ssh_hook.get_conn() as ssh_client:
# Auto apply tty when its required in case of sudo
Expand Down
21 changes: 19 additions & 2 deletions tests/contrib/operators/test_ssh_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand Down Expand Up @@ -58,6 +58,23 @@ def setUp(self):
self.hook = hook
self.dag = dag

def test_hook_created_correctly(self):
TIMEOUT = 20
SSH_ID = "ssh_default"
task = SSHOperator(
task_id="test",
command="echo -n airflow",
dag=self.dag,
timeout=TIMEOUT,
ssh_conn_id="ssh_default"
)
self.assertIsNotNone(task)

task.execute(None)

self.assertEquals(TIMEOUT, task.ssh_hook.timeout)
self.assertEquals(SSH_ID, task.ssh_hook.ssh_conn_id)

def test_json_command_execution(self):
configuration.conf.set("core", "enable_xcom_pickling", "False")
task = SSHOperator(
Expand Down

0 comments on commit 6fbdbce

Please sign in to comment.