Improve error handling in Databricks hook (apache#3570)
* Use float for default value
* Use status code to determine whether an error is retryable
* Fix wrong type in assertion
* Fix style to prevent lines from exceeding 90 characters
* Fix wrong way of checking exception type
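
Taken together, these bullets amount to one behavioral change: instead of retrying only on connection errors and timeouts, the hook now retries any transient failure, including 5xx responses, and sleeps between attempts. A minimal standalone sketch of that pattern, with names simplified and the Airflow specifics stripped out (the real change is in the diff below):

    import time

    import requests
    from requests import exceptions as requests_exceptions


    def _retryable_error(exception):
        # Connection problems and timeouts are transient; a 5xx response points
        # at a server-side outage. Anything else (e.g. a 4xx) is a caller error.
        return isinstance(exception, (requests_exceptions.ConnectionError,
                                      requests_exceptions.Timeout)) \
            or (exception.response is not None
                and exception.response.status_code >= 500)


    def post_with_retries(url, json, retry_limit=3, retry_delay=1.0):
        attempt_num = 1
        while True:
            try:
                response = requests.post(url, json=json)
                response.raise_for_status()
                return response.json()
            except requests_exceptions.RequestException as e:
                if not _retryable_error(e):
                    raise  # caller mistake; do not retry
            if attempt_num == retry_limit:
                raise RuntimeError(
                    'API request failed {} times. Giving up.'.format(retry_limit))
            attempt_num += 1
            time.sleep(retry_delay)

Accepting a float for retry_delay allows sub-second waits, which is also why the default is written as 1.0 (the first bullet).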
betabandido authored and Chris Fei committed Jan 23, 2019
1 parent f4d87b1 commit cb00a50
Showing 4 changed files with 162 additions and 49 deletions.
49 changes: 35 additions & 14 deletions airflow/contrib/hooks/databricks_hook.py
@@ -24,6 +24,7 @@
from airflow.hooks.base_hook import BaseHook
from requests import exceptions as requests_exceptions
from requests.auth import AuthBase
from time import sleep

from airflow.utils.log.logging_mixin import LoggingMixin

@@ -47,7 +48,8 @@ def __init__(
self,
databricks_conn_id='databricks_default',
timeout_seconds=180,
retry_limit=3):
retry_limit=3,
retry_delay=1.0):
"""
:param databricks_conn_id: The name of the databricks connection to use.
:type databricks_conn_id: string
@@ -57,13 +59,17 @@
:param retry_limit: The number of times to retry the connection in case of
service outages.
:type retry_limit: int
:param retry_delay: The number of seconds to wait between retries (it
might be a floating point number).
:type retry_delay: float
"""
self.databricks_conn_id = databricks_conn_id
self.databricks_conn = self.get_connection(databricks_conn_id)
self.timeout_seconds = timeout_seconds
if retry_limit < 1:
raise ValueError('Retry limit must be greater than or equal to 1')
self.retry_limit = retry_limit
self.retry_delay = retry_delay

def _parse_host(self, host):
"""
@@ -118,29 +124,38 @@ def _do_api_call(self, endpoint_info, json):
else:
raise AirflowException('Unexpected HTTP Method: ' + method)

for attempt_num in range(1, self.retry_limit + 1):
attempt_num = 1
while True:
try:
response = request_func(
url,
json=json,
auth=auth,
headers=USER_AGENT_HEADER,
timeout=self.timeout_seconds)
if response.status_code == requests.codes.ok:
return response.json()
else:
response.raise_for_status()
return response.json()
except requests_exceptions.RequestException as e:
if not _retryable_error(e):
# In this case, the user probably made a mistake.
# Don't retry.
raise AirflowException('Response: {0}, Status Code: {1}'.format(
response.content, response.status_code))
except (requests_exceptions.ConnectionError,
requests_exceptions.Timeout) as e:
self.log.error(
'Attempt %s API Request to Databricks failed with reason: %s',
attempt_num, e
)
raise AirflowException(('API requests to Databricks failed {} times. ' +
'Giving up.').format(self.retry_limit))
e.response.content, e.response.status_code))

self._log_request_error(attempt_num, e)

if attempt_num == self.retry_limit:
raise AirflowException(('API requests to Databricks failed {} times. ' +
'Giving up.').format(self.retry_limit))

attempt_num += 1
sleep(self.retry_delay)

def _log_request_error(self, attempt_num, error):
self.log.error(
'Attempt %s API Request to Databricks failed with reason: %s',
attempt_num, error
)

def submit_run(self, json):
"""
Expand Down Expand Up @@ -174,6 +189,12 @@ def cancel_run(self, run_id):
self._do_api_call(CANCEL_RUN_ENDPOINT, json)


def _retryable_error(exception):
return isinstance(exception, requests_exceptions.ConnectionError) \
or isinstance(exception, requests_exceptions.Timeout) \
or exception.response is not None and exception.response.status_code >= 500


RUN_LIFE_CYCLE_STATES = [
'PENDING',
'RUNNING',
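For illustration, a hypothetical caller could wire the new parameter up like this; the connection id and the task payload are placeholders, and a configured Databricks connection is assumed:

    from airflow.contrib.hooks.databricks_hook import DatabricksHook

    # Retry transient failures up to 5 times, sleeping 2.5 seconds between attempts.
    hook = DatabricksHook(
        databricks_conn_id='databricks_default',
        retry_limit=5,
        retry_delay=2.5)
    run_id = hook.submit_run({
        'notebook_task': {'notebook_path': '/Users/user@example.com/Quickstart'},
        'new_cluster': {'spark_version': '2.1.0-db3-scala2.11', 'num_workers': 2},
    })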
8 changes: 7 additions & 1 deletion airflow/contrib/operators/databricks_operator.py
@@ -146,6 +146,9 @@ class DatabricksSubmitRunOperator(BaseOperator):
:param databricks_retry_limit: The number of times to retry if the Databricks
backend is unreachable. Its value must be greater than or equal to 1.
:type databricks_retry_limit: int
:param databricks_retry_delay: Number of seconds to wait between retries (it
might be a floating point number).
:type databricks_retry_delay: float
:param do_xcom_push: Whether we should push run_id and run_page_url to xcom.
:type do_xcom_push: boolean
"""
@@ -168,6 +171,7 @@ def __init__(
databricks_conn_id='databricks_default',
polling_period_seconds=30,
databricks_retry_limit=3,
databricks_retry_delay=1,
do_xcom_push=False,
**kwargs):
"""
@@ -178,6 +182,7 @@ def __init__(
self.databricks_conn_id = databricks_conn_id
self.polling_period_seconds = polling_period_seconds
self.databricks_retry_limit = databricks_retry_limit
self.databricks_retry_delay = databricks_retry_delay
if spark_jar_task is not None:
self.json['spark_jar_task'] = spark_jar_task
if notebook_task is not None:
@@ -232,7 +237,8 @@ def _log_run_page_url(self, url):
def get_hook(self):
return DatabricksHook(
self.databricks_conn_id,
retry_limit=self.databricks_retry_limit)
retry_limit=self.databricks_retry_limit,
retry_delay=self.databricks_retry_delay)

def execute(self, context):
hook = self.get_hook()
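A hypothetical DAG snippet showing how the operator passes both knobs through to the hook; the dag id, schedule, and task payload are illustrative:

    from datetime import datetime

    from airflow import DAG
    from airflow.contrib.operators.databricks_operator import DatabricksSubmitRunOperator

    dag = DAG('databricks_retry_example',
              start_date=datetime(2019, 1, 1),
              schedule_interval=None)

    submit_run = DatabricksSubmitRunOperator(
        task_id='submit_run',
        dag=dag,
        notebook_task={'notebook_path': '/Users/user@example.com/Quickstart'},
        new_cluster={'spark_version': '2.1.0-db3-scala2.11', 'num_workers': 2},
        databricks_retry_limit=5,    # give up after five failed API attempts
        databricks_retry_delay=2.5)  # wait 2.5 seconds between attempts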
144 changes: 114 additions & 30 deletions tests/contrib/hooks/test_databricks_hook.py
@@ -18,15 +18,21 @@
# under the License.
#

import itertools
import json
import unittest

from requests import exceptions as requests_exceptions

from airflow import __version__
from airflow.contrib.hooks.databricks_hook import DatabricksHook, RunState, SUBMIT_RUN_ENDPOINT, _TokenAuth
from airflow.contrib.hooks.databricks_hook import (
DatabricksHook,
RunState,
SUBMIT_RUN_ENDPOINT
)
from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.utils import db
from requests import exceptions as requests_exceptions

try:
from unittest import mock
@@ -79,12 +85,48 @@ def get_run_endpoint(host):
"""
return 'https://{}/api/2.0/jobs/runs/get'.format(host)


def cancel_run_endpoint(host):
"""
Utility function to generate the cancel run endpoint given the host.
"""
return 'https://{}/api/2.0/jobs/runs/cancel'.format(host)


def create_valid_response_mock(content):
response = mock.MagicMock()
response.json.return_value = content
return response


def create_post_side_effect(exception, status_code=500):
if exception != requests_exceptions.HTTPError:
return exception()
else:
response = mock.MagicMock()
response.status_code = status_code
response.raise_for_status.side_effect = exception(response=response)
return response


def setup_mock_requests(
mock_requests,
exception,
status_code=500,
error_count=None,
response_content=None):

side_effect = create_post_side_effect(exception, status_code)

if error_count is None:
# POST requests will fail indefinitely
mock_requests.post.side_effect = itertools.repeat(side_effect)
else:
# POST requests will fail 'error_count' times, and then they will succeed (once)
mock_requests.post.side_effect = \
[side_effect] * error_count + [create_valid_response_mock(response_content)]


class DatabricksHookTest(unittest.TestCase):
"""
Tests for DatabricksHook.
@@ -99,7 +141,7 @@ def setUp(self, session=None):
conn.password = PASSWORD
session.commit()

self.hook = DatabricksHook()
self.hook = DatabricksHook(retry_delay=0)

def test_parse_host_with_proper_host(self):
host = self.hook._parse_host(HOST)
@@ -111,34 +153,85 @@ def test_parse_host_with_scheme(self):

def test_init_bad_retry_limit(self):
with self.assertRaises(ValueError):
DatabricksHook(retry_limit = 0)

@mock.patch('airflow.contrib.hooks.databricks_hook.requests')
def test_do_api_call_with_error_retry(self, mock_requests):
for exception in [requests_exceptions.ConnectionError, requests_exceptions.Timeout]:
with mock.patch.object(self.hook.log, 'error') as mock_errors:
mock_requests.reset_mock()
mock_requests.post.side_effect = exception()
DatabricksHook(retry_limit=0)

def test_do_api_call_retries_with_retryable_error(self):
for exception in [
requests_exceptions.ConnectionError,
requests_exceptions.SSLError,
requests_exceptions.Timeout,
requests_exceptions.ConnectTimeout,
requests_exceptions.HTTPError]:
with mock.patch(
'airflow.contrib.hooks.databricks_hook.requests') as mock_requests, \
mock.patch.object(self.hook.log, 'error') as mock_errors:
setup_mock_requests(mock_requests, exception)

with self.assertRaises(AirflowException):
self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})

self.assertEquals(len(mock_errors.mock_calls), self.hook.retry_limit)
self.assertEquals(mock_errors.call_count, self.hook.retry_limit)

@mock.patch('airflow.contrib.hooks.databricks_hook.requests')
def test_do_api_call_with_bad_status_code(self, mock_requests):
mock_requests.codes.ok = 200
status_code_mock = mock.PropertyMock(return_value=500)
type(mock_requests.post.return_value).status_code = status_code_mock
with self.assertRaises(AirflowException):
self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
def test_do_api_call_does_not_retry_with_non_retryable_error(self, mock_requests):
setup_mock_requests(
mock_requests, requests_exceptions.HTTPError, status_code=400
)

with mock.patch.object(self.hook.log, 'error') as mock_errors:
with self.assertRaises(AirflowException):
self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})

mock_errors.assert_not_called()

def test_do_api_call_succeeds_after_retrying(self):
for exception in [
requests_exceptions.ConnectionError,
requests_exceptions.SSLError,
requests_exceptions.Timeout,
requests_exceptions.ConnectTimeout,
requests_exceptions.HTTPError]:
with mock.patch(
'airflow.contrib.hooks.databricks_hook.requests') as mock_requests, \
mock.patch.object(self.hook.log, 'error') as mock_errors:
setup_mock_requests(
mock_requests,
exception,
error_count=2,
response_content={'run_id': '1'}
)

response = self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})

self.assertEquals(mock_errors.call_count, 2)
self.assertEquals(response, {'run_id': '1'})

@mock.patch('airflow.contrib.hooks.databricks_hook.sleep')
def test_do_api_call_waits_between_retries(self, mock_sleep):
retry_delay = 5
self.hook = DatabricksHook(retry_delay=retry_delay)

for exception in [
requests_exceptions.ConnectionError,
requests_exceptions.SSLError,
requests_exceptions.Timeout,
requests_exceptions.ConnectTimeout,
requests_exceptions.HTTPError]:
with mock.patch(
'airflow.contrib.hooks.databricks_hook.requests') as mock_requests, \
mock.patch.object(self.hook.log, 'error'):
mock_sleep.reset_mock()
setup_mock_requests(mock_requests, exception)

with self.assertRaises(AirflowException):
self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})

self.assertEquals(len(mock_sleep.mock_calls), self.hook.retry_limit - 1)
mock_sleep.assert_called_with(retry_delay)

@mock.patch('airflow.contrib.hooks.databricks_hook.requests')
def test_submit_run(self, mock_requests):
mock_requests.codes.ok = 200
mock_requests.post.return_value.json.return_value = {'run_id': '1'}
status_code_mock = mock.PropertyMock(return_value=200)
type(mock_requests.post.return_value).status_code = status_code_mock
json = {
'notebook_task': NOTEBOOK_TASK,
'new_cluster': NEW_CLUSTER
@@ -158,10 +251,7 @@ def test_submit_run(self, mock_requests):

@mock.patch('airflow.contrib.hooks.databricks_hook.requests')
def test_get_run_page_url(self, mock_requests):
mock_requests.codes.ok = 200
mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE
status_code_mock = mock.PropertyMock(return_value=200)
type(mock_requests.get.return_value).status_code = status_code_mock

run_page_url = self.hook.get_run_page_url(RUN_ID)

@@ -175,10 +265,7 @@ def test_get_run_page_url(self, mock_requests):

@mock.patch('airflow.contrib.hooks.databricks_hook.requests')
def test_get_run_state(self, mock_requests):
mock_requests.codes.ok = 200
mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE
status_code_mock = mock.PropertyMock(return_value=200)
type(mock_requests.get.return_value).status_code = status_code_mock

run_state = self.hook.get_run_state(RUN_ID)

@@ -195,10 +282,7 @@ def test_get_run_state(self, mock_requests):

@mock.patch('airflow.contrib.hooks.databricks_hook.requests')
def test_cancel_run(self, mock_requests):
mock_requests.codes.ok = 200
mock_requests.post.return_value.json.return_value = GET_RUN_RESPONSE
status_code_mock = mock.PropertyMock(return_value=200)
type(mock_requests.post.return_value).status_code = status_code_mock

self.hook.cancel_run(RUN_ID)

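The setup_mock_requests helper above relies on mock treating an iterable side_effect as a per-call script: each call to the mock consumes one element, and any exception instance it yields is raised. A minimal illustration of the two failure modes the helper supports (names are illustrative):

    import itertools
    from unittest import mock

    from requests import exceptions as requests_exceptions

    post = mock.MagicMock()

    # Fail indefinitely: every call raises the same exception instance.
    post.side_effect = itertools.repeat(requests_exceptions.ConnectionError())

    # Fail twice, then succeed: a finite list is consumed one element per call.
    post.side_effect = [requests_exceptions.ConnectionError(),
                        requests_exceptions.ConnectionError(),
                        mock.MagicMock()]  # a successful response mock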
10 changes: 6 additions & 4 deletions tests/contrib/operators/test_databricks_operator.py
@@ -190,8 +190,9 @@ def test_exec_success(self, db_mock_class):
'run_name': TASK_ID
})
db_mock_class.assert_called_once_with(
DEFAULT_CONN_ID,
retry_limit=op.databricks_retry_limit)
DEFAULT_CONN_ID,
retry_limit=op.databricks_retry_limit,
retry_delay=op.databricks_retry_delay)
db_mock.submit_run.assert_called_once_with(expected)
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
db_mock.get_run_state.assert_called_once_with(RUN_ID)
@@ -220,8 +221,9 @@ def test_exec_failure(self, db_mock_class):
'run_name': TASK_ID,
})
db_mock_class.assert_called_once_with(
DEFAULT_CONN_ID,
retry_limit=op.databricks_retry_limit)
DEFAULT_CONN_ID,
retry_limit=op.databricks_retry_limit,
retry_delay=op.databricks_retry_delay)
db_mock.submit_run.assert_called_once_with(expected)
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
db_mock.get_run_state.assert_called_once_with(RUN_ID)
