Skip to content

Commit

Permalink
fix: patches mypy errors for typing
Browse files Browse the repository at this point in the history
  • Loading branch information
dcmshi committed Nov 5, 2024
1 parent ebd4214 commit 58a8056
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 21 deletions.
2 changes: 1 addition & 1 deletion providers/src/airflow/providers/common/sql/sensors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __init__(
self.parameters = parameters
self.success = success
self.failure = failure
self.selector = selector
self.selector = selector or itemgetter(0)
self.fail_on_empty = fail_on_empty
self.hook_params = hook_params
super().__init__(**kwargs)
Expand Down
43 changes: 23 additions & 20 deletions providers/src/airflow/providers/databricks/hooks/databricks_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,19 @@
Callable,
Iterable,
List,
Optional,
Mapping,
Sequence,
TypeVar,
Union,
cast,
overload,
)

from databricks import sql # type: ignore[attr-defined]

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.models.connection import Connection as AirflowConnection
from airflow.providers.common.sql.hooks.sql import DbApiHook, return_single_query_results
from airflow.providers.databricks.hooks.databricks_base import BaseDatabricksHook

Expand All @@ -49,6 +52,7 @@
T = TypeVar("T")



class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
"""
Hook to interact with Databricks SQL.
Expand Down Expand Up @@ -90,7 +94,7 @@ def __init__(
**kwargs,
) -> None:
super().__init__(databricks_conn_id, caller=caller)
self._sql_conn = None
self._sql_conn: Optional[Connection] = None
self._token: str | None = None
self._http_path = http_path
self._sql_endpoint_name = sql_endpoint_name
Expand Down Expand Up @@ -130,7 +134,7 @@ def _get_sql_endpoint_by_name(self, endpoint_name) -> dict[str, Any]:
else:
return endpoint

def get_conn(self) -> Connection:
def get_conn(self) -> AirflowConnection:
"""Return a Databricks SQL connection object."""
if not self._http_path:
if self._sql_endpoint_name:
Expand All @@ -145,20 +149,15 @@ def get_conn(self) -> Connection:
"or sql_endpoint_name should be specified"
)

requires_init = True
if not self._token:
self._token = self._get_token(raise_error=True)
else:
new_token = self._get_token(raise_error=True)
if new_token != self._token:
self._token = new_token
else:
requires_init = False
prev_token = self._token
new_token = self._get_token(raise_error=True)
if not self._token or new_token != self._token:
self._token = new_token

if not self.session_config:
self.session_config = self.databricks_conn.extra_dejson.get("session_configuration")

if not self._sql_conn or requires_init:
if not self._sql_conn or prev_token != new_token:
if self._sql_conn: # close already existing connection
self._sql_conn.close()
self._sql_conn = sql.connect(
Expand All @@ -173,7 +172,10 @@ def get_conn(self) -> Connection:
**self._get_extra_config(),
**self.additional_params,
)
return self._sql_conn

if self._sql_conn is None:
raise AirflowException("SQL connection is not initialized")
return cast(AirflowConnection, self._sql_conn)

@overload # type: ignore[override]
def run(
Expand Down Expand Up @@ -273,22 +275,23 @@ def run(
else:
return results

def _make_common_data_structure(self, result: Sequence[Row] | Row) -> list[tuple] | tuple:
def _make_common_data_structure(self, result: Union[T, Sequence[T]]) -> Union[tuple[Any, ...], list[tuple[Any, ...]]]:
"""Transform the databricks Row objects into namedtuple."""
# The ignored lines below follow the namedtuple docstring, but mypy does not support dynamically
# instantiated namedtuples and never will: https://github.com/python/mypy/issues/848
if isinstance(result, list):
rows: list[Row] = result
rows: Sequence[Row] = result
if not rows:
return []
rows_fields = tuple(rows[0].__fields__)
rows_object = namedtuple("Row", rows_fields, rename=True) # type: ignore
return cast(List[tuple], [rows_object(*row) for row in rows])
else:
row: Row = result
row_fields = tuple(row.__fields__)
return cast(list[tuple[Any, ...]], [rows_object(*row) for row in rows])
elif isinstance(result, Row):
row_fields = tuple(result.__fields__)
row_object = namedtuple("Row", row_fields, rename=True) # type: ignore
return cast(tuple, row_object(*row))
return cast(tuple[Any, ...], row_object(*result))

raise TypeError(f"Expected Sequence[Row] or Row, but got {type(result)}")

def bulk_dump(self, table, tmp_file):
raise NotImplementedError()
Expand Down

0 comments on commit 58a8056

Please sign in to comment.