Skip to content

Commit

Permalink
Custom fetch all handler for vertica to not miss errors (apache#34041)
Browse files Browse the repository at this point in the history
* Custom fetch all handler for vertica to not miss errors

* missing parameter

* Fix test (set nextset to none)

* fix static checks

* fix static-check error

* fix static-check error

* rename variable

* add docstring

* fix docstring
  • Loading branch information
darkag authored Sep 6, 2023
1 parent 6b2a0cb commit 5f47e60
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 3 deletions.
80 changes: 77 additions & 3 deletions airflow/providers/vertica/hooks/vertica.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,45 @@
# under the License.
from __future__ import annotations

from typing import Any, Callable, Iterable, Mapping, overload

from vertica_python import connect

from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.common.sql.hooks.sql import DbApiHook, fetch_all_handler


def vertica_fetch_all_handler(cursor) -> list[tuple] | None:
"""
Replace the default DbApiHook fetch_all_handler in order to fix this issue https://github.com/apache/airflow/issues/32993.
Returned value will not change after the initial call of fetch_all_handler, all the remaining code is here
only to make vertica client throws error.
With Vertica, if you run the following sql (with split_statements set to false):
INSERT INTO MyTable (Key, Label) values (1, 'test 1');
INSERT INTO MyTable (Key, Label) values (1, 'test 2');
INSERT INTO MyTable (Key, Label) values (3, 'test 3');
each insert will have its own result set and if you don't try to fetch data of those result sets
you won't detect error on the second insert.
"""
result = fetch_all_handler(cursor)
# loop on all statement result sets to get errors
if cursor.description is not None:
while cursor.nextset():
if cursor.description is not None:
row = cursor.fetchone()
while row:
row = cursor.fetchone()
return result


class VerticaHook(DbApiHook):
"""Interact with Vertica."""
"""
Interact with Vertica.
This hook use a customized version of default fetch_all_handler named vertica_fetch_all_handler.
"""

conn_name_attr = "vertica_conn_id"
default_conn_name = "vertica_default"
Expand All @@ -32,7 +64,7 @@ class VerticaHook(DbApiHook):
supports_autocommit = True

def get_conn(self) -> connect:
"""Return verticaql connection object."""
"""Return vertica connection object."""
conn = self.get_connection(self.vertica_conn_id) # type: ignore
conn_config = {
"user": conn.login,
Expand Down Expand Up @@ -99,3 +131,45 @@ def get_conn(self) -> connect:

conn = connect(**conn_config)
return conn

@overload
def run(
self,
sql: str | Iterable[str],
autocommit: bool = ...,
parameters: Iterable | Mapping[str, Any] | None = ...,
handler: None = ...,
split_statements: bool = ...,
return_last: bool = ...,
) -> None:
...

@overload
def run(
self,
sql: str | Iterable[str],
autocommit: bool = ...,
parameters: Iterable | Mapping[str, Any] | None = ...,
handler: Callable[[Any], Any] = ...,
split_statements: bool = ...,
return_last: bool = ...,
) -> Any | list[Any]:
...

def run(
self,
sql: str | Iterable[str],
autocommit: bool = False,
parameters: Iterable | Mapping | None = None,
handler: Callable[[Any], Any] | None = None,
split_statements: bool = False,
return_last: bool = True,
) -> Any | list[Any] | None:
"""
Overwrite the common sql run.
Will automatically replace fetch_all_handler by vertica_fetch_all_handler.
"""
if handler == fetch_all_handler:
handler = vertica_fetch_all_handler
return DbApiHook.run(self, sql, autocommit, parameters, handler, split_statements, return_last)
1 change: 1 addition & 0 deletions tests/providers/vertica/hooks/test_vertica.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def test_get_conn_extra_parameters_cast(self, mock_connect):
class TestVerticaHook:
def setup_method(self):
self.cur = mock.MagicMock(rowcount=0)
self.cur.nextset.side_effect = [None]
self.conn = mock.MagicMock()
self.conn.cursor.return_value = self.cur
conn = self.conn
Expand Down

0 comments on commit 5f47e60

Please sign in to comment.