Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle ClientConnectorError to allow for retries #1338

Merged
merged 5 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions astronomer/providers/databricks/hooks/databricks.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

import asyncio
import base64
from typing import Any, Dict, Tuple, cast
from typing import Any, Dict, cast

import aiohttp
from aiohttp import ClientResponseError
from aiohttp import ClientConnectorError, ClientResponseError
from airflow import __version__
from airflow.exceptions import AirflowException
from airflow.providers.databricks.hooks.databricks import (
Expand Down Expand Up @@ -50,7 +52,7 @@ async def get_run_state_async(self, run_id: str) -> RunState:

return RunState(life_cycle_state, result_state, state_message)

async def get_run_response(self, run_id: str) -> Dict[str, Any]:
async def get_run_response(self, run_id: str) -> dict[str, Any]:
"""
Makes Async API call to get the run state info.

Expand All @@ -60,7 +62,7 @@ async def get_run_response(self, run_id: str) -> Dict[str, Any]:
response = await self._do_api_call_async(GET_RUN_ENDPOINT, json)
return response

async def get_run_output_response(self, task_run_id: str) -> Dict[str, Any]:
async def get_run_output_response(self, task_run_id: str) -> dict[str, Any]:
"""
Retrieves run output of the run.

Expand All @@ -71,8 +73,8 @@ async def get_run_output_response(self, task_run_id: str) -> Dict[str, Any]:
return run_output

async def _do_api_call_async(
self, endpoint_info: Tuple[str, str], json: Dict[str, Any]
) -> Dict[str, Any]:
self, endpoint_info: tuple[str, str], json: dict[str, Any]
) -> dict[str, Any]:
"""
Utility function to perform an asynchronous API call with retries

Expand Down Expand Up @@ -133,7 +135,7 @@ async def _do_api_call_async(
)
response.raise_for_status()
return cast(Dict[str, Any], await response.json())
except ClientResponseError as e:
except (ClientConnectorError, ClientResponseError) as e:
if not self._retryable_error_async(e):
# In this case, the user probably made a mistake.
# Don't retry; raise the exception instead.
Expand All @@ -150,7 +152,8 @@ async def _do_api_call_async(
attempt_num += 1
await asyncio.sleep(self.retry_delay)

def _retryable_error_async(self, exception: ClientResponseError) -> bool:
@staticmethod
def _retryable_error_async(exception: ClientConnectorError | ClientResponseError) -> bool:
"""
Determines whether the API call that raised the given exception might
succeed on a subsequent attempt.
Expand All @@ -164,4 +167,6 @@ def _retryable_error_async(self, exception: ClientResponseError) -> bool:
:return: if the status is retryable
:rtype: bool
"""
return exception.status >= 500
if isinstance(exception, ClientResponseError):
return exception.status >= 500
return True
5 changes: 5 additions & 0 deletions tests/databricks/hooks/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from unittest.mock import MagicMock

import pytest
from aiohttp import ClientConnectorError
from airflow import __version__ as provider_version
from airflow.exceptions import AirflowException
from airflow.providers.databricks.hooks.databricks import (
Expand Down Expand Up @@ -291,3 +292,7 @@ async def test_get_run_output_response(self, mock_do_api_async):
mock_do_api_async.return_value = MOCK_GET_OUTPUT_RESPONSE
run_output = await hook.get_run_output_response(RUN_ID)
assert run_output == MOCK_GET_OUTPUT_RESPONSE

def test___retryable_error_async_with_client_connector_error(self):
    """``_retryable_error_async`` must report a ``ClientConnectorError`` as retryable."""
    # An empty connection key and a bare OSError are enough to build the
    # exception; the check under test only looks at the exception's type.
    connector_error = ClientConnectorError(connection_key="", os_error=OSError())
    is_retryable = DatabricksHookAsync._retryable_error_async(connector_error)
    assert is_retryable is True
Loading