Skip to content

Commit

Permalink
Add exception and test case
Browse files Browse the repository at this point in the history
  • Loading branch information
aviemzur committed Dec 25, 2019
1 parent 7c316cb commit d5d7fc0
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 2 deletions.
5 changes: 3 additions & 2 deletions airflow/contrib/hooks/emr_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,12 @@ def get_cluster_id_by_name(self, emr_cluster_name, cluster_states):

if len(matching_clusters) == 1:
cluster_id = matching_clusters[0]['Id']
self.log.info('Found cluster name = %s id = %s' % (emr_cluster_name, cluster_id))
self.log.info(f'Found cluster name = {emr_cluster_name} id = {cluster_id}')
return cluster_id
elif len(matching_clusters) > 1:
raise AirflowException('More than one cluster found for name = %s' % emr_cluster_name)
raise AirflowException(f'More than one cluster found for name {emr_cluster_name}')
else:
self.log.info(f'No cluster found for name {emr_cluster_name}')
return None

def create_job_flow(self, job_flow_overrides):
Expand Down
2 changes: 2 additions & 0 deletions airflow/contrib/operators/emr_add_steps_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ def execute(self, context):

if not job_flow_id:
job_flow_id = emr_hook.get_cluster_id_by_name(self.job_flow_name, self.cluster_states)
if not job_flow_id:
raise AirflowException('No cluster found for name = %s' % self.job_flow_name)

if self.do_xcom_push:
context['ti'].xcom_push(key='job_flow_id', value=job_flow_id)
Expand Down
20 changes: 20 additions & 0 deletions tests/contrib/operators/test_emr_add_steps_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from airflow import DAG
from airflow.contrib.operators.emr_add_steps_operator import EmrAddStepsOperator
from airflow.exceptions import AirflowException
from airflow.models import TaskInstance
from airflow.utils import timezone

Expand Down Expand Up @@ -128,6 +129,25 @@ def test_init_with_cluster_name(self):

ti.xcom_push.assert_called_once_with(key='job_flow_id', value=expected_job_flow_id)

def test_init_with_nonexistent_cluster_name(self):
cluster_name = 'test_cluster'

with patch('airflow.contrib.hooks.emr_hook.EmrHook.get_cluster_id_by_name') \
as mock_get_cluster_id_by_name:
mock_get_cluster_id_by_name.return_value = None

operator = EmrAddStepsOperator(
task_id='test_task',
job_flow_name=cluster_name,
cluster_states=['RUNNING', 'WAITING'],
aws_conn_id='aws_default',
dag=DAG('test_dag_id', default_args=self.args)
)

with self.assertRaises(AirflowException) as error:
operator.execute(self.mock_context)
self.assertEqual(str(error.exception), 'No cluster found for name = test_cluster')


if __name__ == '__main__':
unittest.main()

0 comments on commit d5d7fc0

Please sign in to comment.