Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AIRFLOW-6432] Raise appropriate exception in EmrAddStepsOperator when using job_flow_name and no cluster is found #6898

5 changes: 3 additions & 2 deletions airflow/contrib/hooks/emr_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,12 @@ def get_cluster_id_by_name(self, emr_cluster_name, cluster_states):

if len(matching_clusters) == 1:
cluster_id = matching_clusters[0]['Id']
self.log.info('Found cluster name = %s id = %s' % (emr_cluster_name, cluster_id))
self.log.info('Found cluster name = %s id = %s', emr_cluster_name, cluster_id)
return cluster_id
elif len(matching_clusters) > 1:
raise AirflowException('More than one cluster found for name = %s' % emr_cluster_name)
raise AirflowException('More than one cluster found for name %s', emr_cluster_name)
else:
self.log.info('No cluster found for name %s', emr_cluster_name)
return None

def create_job_flow(self, job_flow_overrides):
Expand Down
8 changes: 5 additions & 3 deletions airflow/contrib/operators/emr_add_steps_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,14 @@ def __init__(
self.steps = steps

def execute(self, context):
emr = EmrHook(aws_conn_id=self.aws_conn_id).get_conn()
emr_hook = EmrHook(aws_conn_id=self.aws_conn_id)

job_flow_id = self.job_flow_id
emr = emr_hook.get_conn()

job_flow_id = self.job_flow_id or emr_hook.get_cluster_id_by_name(self.job_flow_name,
self.cluster_states)
if not job_flow_id:
job_flow_id = emr.get_cluster_id_by_name(self.job_flow_name, self.cluster_states)
raise AirflowException(f'No cluster found for name: {self.job_flow_name}')
Copy link
Member

@turbaszek turbaszek Dec 31, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now I'm starting to wonder: when will this exception be reached? When emr_hook.get_cluster_id_by_name returns None. Is that possible, or will this fail when calling emr_hook.get_cluster_id_by_name with self.job_flow_name=None?

Copy link
Member Author

@aviemzur aviemzur Dec 31, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you ask for a job_flow_name that doesn't exist, emr_hook.get_cluster_id_by_name returns None:
emr_hook.py
test case


if self.do_xcom_push:
context['ti'].xcom_push(key='job_flow_id', value=job_flow_id)
Expand Down
37 changes: 30 additions & 7 deletions tests/contrib/operators/test_emr_add_steps_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from airflow import DAG
from airflow.contrib.operators.emr_add_steps_operator import EmrAddStepsOperator
from airflow.exceptions import AirflowException
from airflow.models import TaskInstance
from airflow.utils import timezone

Expand Down Expand Up @@ -107,23 +108,45 @@ def test_execute_returns_step_id(self):
def test_init_with_cluster_name(self):
expected_job_flow_id = 'j-1231231234'

self.emr_client_mock.get_cluster_id_by_name.return_value = expected_job_flow_id
self.emr_client_mock.add_job_flow_steps.return_value = ADD_STEPS_SUCCESS_RETURN

with patch('boto3.session.Session', self.boto3_session_mock):
with patch('airflow.contrib.hooks.emr_hook.EmrHook.get_cluster_id_by_name') \
as mock_get_cluster_id_by_name:
mock_get_cluster_id_by_name.return_value = expected_job_flow_id

operator = EmrAddStepsOperator(
task_id='test_task',
job_flow_name='test_cluster',
cluster_states=['RUNNING', 'WAITING'],
aws_conn_id='aws_default',
dag=DAG('test_dag_id', default_args=self.args)
)

operator.execute(self.mock_context)

ti = self.mock_context['ti']

ti.xcom_push.assert_called_once_with(key='job_flow_id', value=expected_job_flow_id)

def test_init_with_nonexistent_cluster_name(self):
cluster_name = 'test_cluster'

with patch('airflow.contrib.hooks.emr_hook.EmrHook.get_cluster_id_by_name') \
as mock_get_cluster_id_by_name:
mock_get_cluster_id_by_name.return_value = None

operator = EmrAddStepsOperator(
task_id='test_task',
job_flow_name='test_cluster',
job_flow_name=cluster_name,
cluster_states=['RUNNING', 'WAITING'],
aws_conn_id='aws_default',
dag=DAG('test_dag_id', default_args=self.args)
)

operator.execute(self.mock_context)

ti = self.mock_context['ti']

ti.xcom_push.assert_any_call(key='job_flow_id', value=expected_job_flow_id)
with self.assertRaises(AirflowException) as error:
operator.execute(self.mock_context)
self.assertEqual(str(error.exception), f'No cluster found for name: {cluster_name}')


if __name__ == '__main__':
Expand Down