From 58a8056144468b311bdf5e7f891586a3ac0bb0a4 Mon Sep 17 00:00:00 2001 From: DShi Date: Fri, 25 Oct 2024 11:19:52 -0400 Subject: [PATCH] fix: patches mypy errors for typing --- .../providers/common/sql/sensors/sql.py | 2 +- .../databricks/hooks/databricks_sql.py | 43 ++++++++++--------- 2 files changed, 24 insertions(+), 21 deletions(-) diff --git a/providers/src/airflow/providers/common/sql/sensors/sql.py b/providers/src/airflow/providers/common/sql/sensors/sql.py index ac7eb7b4a4005..b213d2da0600d 100644 --- a/providers/src/airflow/providers/common/sql/sensors/sql.py +++ b/providers/src/airflow/providers/common/sql/sensors/sql.py @@ -80,7 +80,7 @@ def __init__( self.parameters = parameters self.success = success self.failure = failure - self.selector = selector + self.selector = selector or itemgetter(0) self.fail_on_empty = fail_on_empty self.hook_params = hook_params super().__init__(**kwargs) diff --git a/providers/src/airflow/providers/databricks/hooks/databricks_sql.py b/providers/src/airflow/providers/databricks/hooks/databricks_sql.py index fed85835d095c..dd485ec586655 100644 --- a/providers/src/airflow/providers/databricks/hooks/databricks_sql.py +++ b/providers/src/airflow/providers/databricks/hooks/databricks_sql.py @@ -26,9 +26,11 @@ Callable, Iterable, List, + Optional, Mapping, Sequence, TypeVar, + Union, cast, overload, ) @@ -36,6 +38,7 @@ from databricks import sql # type: ignore[attr-defined] from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning +from airflow.models.connection import Connection as AirflowConnection from airflow.providers.common.sql.hooks.sql import DbApiHook, return_single_query_results from airflow.providers.databricks.hooks.databricks_base import BaseDatabricksHook @@ -49,6 +52,7 @@ T = TypeVar("T") + class DatabricksSqlHook(BaseDatabricksHook, DbApiHook): """ Hook to interact with Databricks SQL. @@ -90,7 +94,7 @@ def __init__( **kwargs, ) -> None: super().__init__(databricks_conn_id, caller=caller) - self._sql_conn = None + self._sql_conn: Optional[Connection] = None self._token: str | None = None self._http_path = http_path self._sql_endpoint_name = sql_endpoint_name @@ -130,7 +134,7 @@ def _get_sql_endpoint_by_name(self, endpoint_name) -> dict[str, Any]: else: return endpoint - def get_conn(self) -> Connection: + def get_conn(self) -> AirflowConnection: """Return a Databricks SQL connection object.""" if not self._http_path: if self._sql_endpoint_name: @@ -145,20 +149,15 @@ def get_conn(self) -> Connection: "or sql_endpoint_name should be specified" ) - requires_init = True - if not self._token: - self._token = self._get_token(raise_error=True) - else: - new_token = self._get_token(raise_error=True) - if new_token != self._token: - self._token = new_token - else: - requires_init = False + prev_token = self._token + new_token = self._get_token(raise_error=True) + if not self._token or new_token != self._token: + self._token = new_token if not self.session_config: self.session_config = self.databricks_conn.extra_dejson.get("session_configuration") - if not self._sql_conn or requires_init: + if not self._sql_conn or prev_token != new_token: if self._sql_conn: # close already existing connection self._sql_conn.close() self._sql_conn = sql.connect( @@ -173,7 +172,10 @@ def get_conn(self) -> Connection: **self._get_extra_config(), **self.additional_params, ) - return self._sql_conn + + if self._sql_conn is None: + raise AirflowException("SQL connection is not initialized") + return cast(AirflowConnection, self._sql_conn) @overload # type: ignore[override] def run( @@ -273,22 +275,23 @@ def run( else: return results - def _make_common_data_structure(self, result: Sequence[Row] | Row) -> list[tuple] | tuple: + def _make_common_data_structure(self, result: Union[T, Sequence[T]]) -> Union[tuple[Any, ...], list[tuple[Any, ...]]]: """Transform the databricks Row objects into namedtuple.""" # Below ignored lines respect namedtuple docstring, but mypy do not support dynamically # instantiated namedtuple, and will never do: https://github.com/python/mypy/issues/848 if isinstance(result, list): - rows: list[Row] = result + rows: Sequence[Row] = result if not rows: return [] rows_fields = tuple(rows[0].__fields__) rows_object = namedtuple("Row", rows_fields, rename=True) # type: ignore - return cast(List[tuple], [rows_object(*row) for row in rows]) - else: - row: Row = result - row_fields = tuple(row.__fields__) + return cast(list[tuple[Any, ...]], [rows_object(*row) for row in rows]) + elif isinstance(result, Row): + row_fields = tuple(result.__fields__) row_object = namedtuple("Row", row_fields, rename=True) # type: ignore - return cast(tuple, row_object(*row)) + return cast(tuple[Any, ...], row_object(*result)) + + raise TypeError(f"Expected Sequence[Row] or Row, but got {type(result)}") def bulk_dump(self, table, tmp_file): raise NotImplementedError()