diff --git a/airflow/contrib/hooks/databricks_hook.py b/airflow/contrib/hooks/databricks_hook.py index 54f00e00907c0..5b97a0eba0391 100644 --- a/airflow/contrib/hooks/databricks_hook.py +++ b/airflow/contrib/hooks/databricks_hook.py @@ -24,6 +24,7 @@ from airflow.hooks.base_hook import BaseHook from requests import exceptions as requests_exceptions from requests.auth import AuthBase +from time import sleep from airflow.utils.log.logging_mixin import LoggingMixin @@ -47,7 +48,8 @@ def __init__( self, databricks_conn_id='databricks_default', timeout_seconds=180, - retry_limit=3): + retry_limit=3, + retry_delay=1.0): """ :param databricks_conn_id: The name of the databricks connection to use. :type databricks_conn_id: string @@ -57,6 +59,9 @@ def __init__( :param retry_limit: The number of times to retry the connection in case of service outages. :type retry_limit: int + :param retry_delay: The number of seconds to wait between retries (it + might be a floating point number). + :type retry_delay: float """ self.databricks_conn_id = databricks_conn_id self.databricks_conn = self.get_connection(databricks_conn_id) @@ -64,6 +69,7 @@ def __init__( if retry_limit < 1: raise ValueError('Retry limit must be greater than equal to 1') self.retry_limit = retry_limit + self.retry_delay = retry_delay @staticmethod def _parse_host(host): @@ -119,7 +125,8 @@ def _do_api_call(self, endpoint_info, json): else: raise AirflowException('Unexpected HTTP Method: ' + method) - for attempt_num in range(1, self.retry_limit + 1): + attempt_num = 1 + while True: try: response = request_func( url, @@ -127,21 +134,29 @@ def _do_api_call(self, endpoint_info, json): auth=auth, headers=USER_AGENT_HEADER, timeout=self.timeout_seconds) - if response.status_code == requests.codes.ok: - return response.json() - else: + response.raise_for_status() + return response.json() + except requests_exceptions.RequestException as e: + if not _retryable_error(e): # In this case, the user probably made a mistake. # Don't retry. raise AirflowException('Response: {0}, Status Code: {1}'.format( - response.content, response.status_code)) - except (requests_exceptions.ConnectionError, - requests_exceptions.Timeout) as e: - self.log.error( - 'Attempt %s API Request to Databricks failed with reason: %s', - attempt_num, e - ) - raise AirflowException(('API requests to Databricks failed {} times. ' + - 'Giving up.').format(self.retry_limit)) + e.response.content, e.response.status_code)) + + self._log_request_error(attempt_num, e) + + if attempt_num == self.retry_limit: + raise AirflowException(('API requests to Databricks failed {} times. ' + + 'Giving up.').format(self.retry_limit)) + + attempt_num += 1 + sleep(self.retry_delay) + + def _log_request_error(self, attempt_num, error): + self.log.error( + 'Attempt %s API Request to Databricks failed with reason: %s', + attempt_num, error + ) def submit_run(self, json): """ @@ -175,6 +190,12 @@ def cancel_run(self, run_id): self._do_api_call(CANCEL_RUN_ENDPOINT, json) +def _retryable_error(exception): + return isinstance(exception, requests_exceptions.ConnectionError) \ + or isinstance(exception, requests_exceptions.Timeout) \ + or exception.response is not None and exception.response.status_code >= 500 + + RUN_LIFE_CYCLE_STATES = [ 'PENDING', 'RUNNING', diff --git a/airflow/contrib/operators/databricks_operator.py b/airflow/contrib/operators/databricks_operator.py index 7b8d522dba85b..3245a99256502 100644 --- a/airflow/contrib/operators/databricks_operator.py +++ b/airflow/contrib/operators/databricks_operator.py @@ -146,6 +146,9 @@ class DatabricksSubmitRunOperator(BaseOperator): :param databricks_retry_limit: Amount of times retry if the Databricks backend is unreachable. Its value must be greater than or equal to 1. :type databricks_retry_limit: int + :param databricks_retry_delay: Number of seconds to wait between retries (it + might be a floating point number). + :type databricks_retry_delay: float :param do_xcom_push: Whether we should push run_id and run_page_url to xcom. :type do_xcom_push: boolean """ @@ -168,6 +171,7 @@ def __init__( databricks_conn_id='databricks_default', polling_period_seconds=30, databricks_retry_limit=3, + databricks_retry_delay=1, do_xcom_push=False, **kwargs): """ @@ -178,6 +182,7 @@ def __init__( self.databricks_conn_id = databricks_conn_id self.polling_period_seconds = polling_period_seconds self.databricks_retry_limit = databricks_retry_limit + self.databricks_retry_delay = databricks_retry_delay if spark_jar_task is not None: self.json['spark_jar_task'] = spark_jar_task if notebook_task is not None: @@ -232,7 +237,8 @@ def _log_run_page_url(self, url): def get_hook(self): return DatabricksHook( self.databricks_conn_id, - retry_limit=self.databricks_retry_limit) + retry_limit=self.databricks_retry_limit, + retry_delay=self.databricks_retry_delay) def execute(self, context): hook = self.get_hook() diff --git a/tests/contrib/hooks/test_databricks_hook.py b/tests/contrib/hooks/test_databricks_hook.py index aca8dd96004b4..a022431899b4d 100644 --- a/tests/contrib/hooks/test_databricks_hook.py +++ b/tests/contrib/hooks/test_databricks_hook.py @@ -18,15 +18,21 @@ # under the License. # +import itertools import json import unittest +from requests import exceptions as requests_exceptions + from airflow import __version__ -from airflow.contrib.hooks.databricks_hook import DatabricksHook, RunState, SUBMIT_RUN_ENDPOINT, _TokenAuth +from airflow.contrib.hooks.databricks_hook import ( + DatabricksHook, + RunState, + SUBMIT_RUN_ENDPOINT +) from airflow.exceptions import AirflowException from airflow.models import Connection from airflow.utils import db -from requests import exceptions as requests_exceptions try: from unittest import mock @@ -79,12 +85,48 @@ def get_run_endpoint(host): """ return 'https://{}/api/2.0/jobs/runs/get'.format(host) + def cancel_run_endpoint(host): """ Utility function to generate the get run endpoint given the host. """ return 'https://{}/api/2.0/jobs/runs/cancel'.format(host) + +def create_valid_response_mock(content): + response = mock.MagicMock() + response.json.return_value = content + return response + + +def create_post_side_effect(exception, status_code=500): + if exception != requests_exceptions.HTTPError: + return exception() + else: + response = mock.MagicMock() + response.status_code = status_code + response.raise_for_status.side_effect = exception(response=response) + return response + + +def setup_mock_requests( + mock_requests, + exception, + status_code=500, + error_count=None, + response_content=None): + + side_effect = create_post_side_effect(exception, status_code) + + if error_count is None: + # POST requests will fail indefinitely + mock_requests.post.side_effect = itertools.repeat(side_effect) + else: + # POST requests will fail 'error_count' times, and then they will succeed (once) + mock_requests.post.side_effect = \ + [side_effect] * error_count + [create_valid_response_mock(response_content)] + + class DatabricksHookTest(unittest.TestCase): """ Tests for DatabricksHook. @@ -99,7 +141,7 @@ def setUp(self, session=None): conn.password = PASSWORD session.commit() - self.hook = DatabricksHook() + self.hook = DatabricksHook(retry_delay=0) def test_parse_host_with_proper_host(self): host = self.hook._parse_host(HOST) @@ -111,34 +153,85 @@ def test_parse_host_with_scheme(self): def test_init_bad_retry_limit(self): with self.assertRaises(ValueError): - DatabricksHook(retry_limit = 0) - - @mock.patch('airflow.contrib.hooks.databricks_hook.requests') - def test_do_api_call_with_error_retry(self, mock_requests): - for exception in [requests_exceptions.ConnectionError, requests_exceptions.Timeout]: - with mock.patch.object(self.hook.log, 'error') as mock_errors: - mock_requests.reset_mock() - mock_requests.post.side_effect = exception() + DatabricksHook(retry_limit=0) + + def test_do_api_call_retries_with_retryable_error(self): + for exception in [ + requests_exceptions.ConnectionError, + requests_exceptions.SSLError, + requests_exceptions.Timeout, + requests_exceptions.ConnectTimeout, + requests_exceptions.HTTPError]: + with mock.patch( + 'airflow.contrib.hooks.databricks_hook.requests') as mock_requests, \ + mock.patch.object(self.hook.log, 'error') as mock_errors: + setup_mock_requests(mock_requests, exception) with self.assertRaises(AirflowException): self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {}) - self.assertEquals(len(mock_errors.mock_calls), self.hook.retry_limit) + self.assertEquals(mock_errors.call_count, self.hook.retry_limit) @mock.patch('airflow.contrib.hooks.databricks_hook.requests') - def test_do_api_call_with_bad_status_code(self, mock_requests): - mock_requests.codes.ok = 200 - status_code_mock = mock.PropertyMock(return_value=500) - type(mock_requests.post.return_value).status_code = status_code_mock - with self.assertRaises(AirflowException): - self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {}) + def test_do_api_call_does_not_retry_with_non_retryable_error(self, mock_requests): + setup_mock_requests( + mock_requests, requests_exceptions.HTTPError, status_code=400 + ) + + with mock.patch.object(self.hook.log, 'error') as mock_errors: + with self.assertRaises(AirflowException): + self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {}) + + mock_errors.assert_not_called() + + def test_do_api_call_succeeds_after_retrying(self): + for exception in [ + requests_exceptions.ConnectionError, + requests_exceptions.SSLError, + requests_exceptions.Timeout, + requests_exceptions.ConnectTimeout, + requests_exceptions.HTTPError]: + with mock.patch( + 'airflow.contrib.hooks.databricks_hook.requests') as mock_requests, \ + mock.patch.object(self.hook.log, 'error') as mock_errors: + setup_mock_requests( + mock_requests, + exception, + error_count=2, + response_content={'run_id': '1'} + ) + + response = self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {}) + + self.assertEquals(mock_errors.call_count, 2) + self.assertEquals(response, {'run_id': '1'}) + + @mock.patch('airflow.contrib.hooks.databricks_hook.sleep') + def test_do_api_call_waits_between_retries(self, mock_sleep): + retry_delay = 5 + self.hook = DatabricksHook(retry_delay=retry_delay) + + for exception in [ + requests_exceptions.ConnectionError, + requests_exceptions.SSLError, + requests_exceptions.Timeout, + requests_exceptions.ConnectTimeout, + requests_exceptions.HTTPError]: + with mock.patch( + 'airflow.contrib.hooks.databricks_hook.requests') as mock_requests, \ + mock.patch.object(self.hook.log, 'error'): + mock_sleep.reset_mock() + setup_mock_requests(mock_requests, exception) + + with self.assertRaises(AirflowException): + self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {}) + + self.assertEquals(len(mock_sleep.mock_calls), self.hook.retry_limit - 1) + mock_sleep.assert_called_with(retry_delay) @mock.patch('airflow.contrib.hooks.databricks_hook.requests') def test_submit_run(self, mock_requests): - mock_requests.codes.ok = 200 mock_requests.post.return_value.json.return_value = {'run_id': '1'} - status_code_mock = mock.PropertyMock(return_value=200) - type(mock_requests.post.return_value).status_code = status_code_mock json = { 'notebook_task': NOTEBOOK_TASK, 'new_cluster': NEW_CLUSTER @@ -158,10 +251,7 @@ def test_submit_run(self, mock_requests): @mock.patch('airflow.contrib.hooks.databricks_hook.requests') def test_get_run_page_url(self, mock_requests): - mock_requests.codes.ok = 200 mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE - status_code_mock = mock.PropertyMock(return_value=200) - type(mock_requests.get.return_value).status_code = status_code_mock run_page_url = self.hook.get_run_page_url(RUN_ID) @@ -175,10 +265,7 @@ def test_get_run_page_url(self, mock_requests): @mock.patch('airflow.contrib.hooks.databricks_hook.requests') def test_get_run_state(self, mock_requests): - mock_requests.codes.ok = 200 mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE - status_code_mock = mock.PropertyMock(return_value=200) - type(mock_requests.get.return_value).status_code = status_code_mock run_state = self.hook.get_run_state(RUN_ID) @@ -195,10 +282,7 @@ def test_get_run_state(self, mock_requests): @mock.patch('airflow.contrib.hooks.databricks_hook.requests') def test_cancel_run(self, mock_requests): - mock_requests.codes.ok = 200 mock_requests.post.return_value.json.return_value = GET_RUN_RESPONSE - status_code_mock = mock.PropertyMock(return_value=200) - type(mock_requests.post.return_value).status_code = status_code_mock self.hook.cancel_run(RUN_ID) diff --git a/tests/contrib/operators/test_databricks_operator.py b/tests/contrib/operators/test_databricks_operator.py index f77da2ec18eda..afe1a92f28d9e 100644 --- a/tests/contrib/operators/test_databricks_operator.py +++ b/tests/contrib/operators/test_databricks_operator.py @@ -190,8 +190,9 @@ def test_exec_success(self, db_mock_class): 'run_name': TASK_ID }) db_mock_class.assert_called_once_with( - DEFAULT_CONN_ID, - retry_limit=op.databricks_retry_limit) + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay) db_mock.submit_run.assert_called_once_with(expected) db_mock.get_run_page_url.assert_called_once_with(RUN_ID) db_mock.get_run_state.assert_called_once_with(RUN_ID) @@ -220,8 +221,9 @@ def test_exec_failure(self, db_mock_class): 'run_name': TASK_ID, }) db_mock_class.assert_called_once_with( - DEFAULT_CONN_ID, - retry_limit=op.databricks_retry_limit) + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay) db_mock.submit_run.assert_called_once_with(expected) db_mock.get_run_page_url.assert_called_once_with(RUN_ID) db_mock.get_run_state.assert_called_once_with(RUN_ID)