diff --git a/records_mover/db/bigquery/loader.py b/records_mover/db/bigquery/loader.py index 7b4dfa4f9..feee793f9 100644 --- a/records_mover/db/bigquery/loader.py +++ b/records_mover/db/bigquery/loader.py @@ -20,6 +20,9 @@ from ..loader import LoaderFromFileobj from ..errors import NoTemporaryBucketConfiguration from ...check_db_conn_engine import check_db_conn_engine +from ..db_conn_composable_methods import (composable_get_db_conn, + composable_set_db_conn, + composable_del_db_conn) logger = logging.getLogger(__name__) @@ -39,19 +42,9 @@ def __init__(self, self.gcs_temp_base_loc = gcs_temp_base_loc self.conn_opened_here = False - def get_db_conn(self) -> sqlalchemy.engine.Connection: - if self._db_conn is None: - self._db_conn = self.db_engine.connect() - self.conn_opened_here = True - logger.debug(f"Opened connection to database within {self} because none was provided.") - return self._db_conn - - def set_db_conn(self, db_conn: Optional[sqlalchemy.engine.Connection]) -> None: - self._db_conn = db_conn - - def del_db_conn(self) -> None: - if self.conn_opened_here: - self.db_conn.close() + get_db_conn = composable_get_db_conn + set_db_conn = composable_set_db_conn + del_db_conn = composable_del_db_conn db_conn = property(get_db_conn, set_db_conn, del_db_conn) diff --git a/records_mover/db/db_conn_composable_methods.py b/records_mover/db/db_conn_composable_methods.py new file mode 100644 index 000000000..6bf267e0f --- /dev/null +++ b/records_mover/db/db_conn_composable_methods.py @@ -0,0 +1,23 @@ +import logging +import sqlalchemy +from typing import Optional + +logger = logging.getLogger(__name__) + + +def composable_get_db_conn(self) -> sqlalchemy.engine.Connection: + if self._db_conn is None: + self._db_conn = self.db_engine.connect() + self.conn_opened_here = True + logger.debug(f"Opened connection to database within {self} because none was provided.") + return self._db_conn + + +def composable_set_db_conn(self, db_conn: Optional[sqlalchemy.engine.Connection]) -> None: + self._db_conn = db_conn + + +def composable_del_db_conn(self) -> None: + if self.conn_opened_here: + self.db_conn.close() + self.db_conn = None diff --git a/records_mover/db/driver.py b/records_mover/db/driver.py index ae517ccef..ee6a27e50 100644 --- a/records_mover/db/driver.py +++ b/records_mover/db/driver.py @@ -11,6 +11,9 @@ from abc import ABCMeta, abstractmethod from records_mover.records import RecordsSchema from typing import Union, Dict, List, Tuple, Optional, TYPE_CHECKING +from .db_conn_composable_methods import (composable_get_db_conn, + composable_set_db_conn, + composable_del_db_conn) if TYPE_CHECKING: from typing_extensions import Literal # noqa @@ -31,19 +34,9 @@ def __init__(self, self.conn_opened_here = False self.meta = MetaData() - def get_db_conn(self) -> sqlalchemy.engine.Connection: - if self._db_conn is None: - self._db_conn = self.db_engine.connect() - self.conn_opened_here = True - logger.debug(f"Opened connection to database within {self} because none was provided.") - return self._db_conn - - def set_db_conn(self, db_conn: Optional[sqlalchemy.engine.Connection]) -> None: - self._db_conn = db_conn - - def del_db_conn(self) -> None: - if self.conn_opened_here: - self.db_conn.close() + get_db_conn = composable_get_db_conn + set_db_conn = composable_set_db_conn + del_db_conn = composable_del_db_conn db_conn = property(get_db_conn, set_db_conn, del_db_conn) diff --git a/records_mover/db/mysql/loader.py b/records_mover/db/mysql/loader.py index 34f1231c6..9ccc8e4b1 100644 --- a/records_mover/db/mysql/loader.py +++ b/records_mover/db/mysql/loader.py @@ -13,6 +13,9 @@ import logging import tempfile from ...check_db_conn_engine import check_db_conn_engine +from ..db_conn_composable_methods import (composable_get_db_conn, + composable_set_db_conn, + composable_del_db_conn) logger = logging.getLogger(__name__) @@ -30,19 +33,9 @@ def __init__(self, self.db_engine = db_engine self.url_resolver = url_resolver - def get_db_conn(self) -> sqlalchemy.engine.Connection: - if self._db_conn is None: - self._db_conn = self.db_engine.connect() - self.conn_opened_here = True - logger.debug(f"Opened connection to database within {self} because none was provided.") - return self._db_conn - - def set_db_conn(self, db_conn: Optional[sqlalchemy.engine.Connection]) -> None: - self._db_conn = db_conn - - def del_db_conn(self) -> None: - if self.conn_opened_here: - self.db_conn.close() + get_db_conn = composable_get_db_conn + set_db_conn = composable_set_db_conn + del_db_conn = composable_del_db_conn db_conn = property(get_db_conn, set_db_conn, del_db_conn) diff --git a/records_mover/db/postgres/loader.py b/records_mover/db/postgres/loader.py index a7b8f0d65..13a5779e3 100644 --- a/records_mover/db/postgres/loader.py +++ b/records_mover/db/postgres/loader.py @@ -14,6 +14,9 @@ from ..loader import LoaderFromFileobj import logging from ...check_db_conn_engine import check_db_conn_engine +from ..db_conn_composable_methods import (composable_get_db_conn, + composable_set_db_conn, + composable_del_db_conn) logger = logging.getLogger(__name__) @@ -33,19 +36,9 @@ def __init__(self, self.meta = meta self.conn_opened_here = False - def get_db_conn(self) -> sqlalchemy.engine.Connection: - if self._db_conn is None: - self._db_conn = self.db_engine.connect() - self.conn_opened_here = True - logger.debug(f"Opened connection to database within {self} because none was provided.") - return self._db_conn - - def set_db_conn(self, db_conn: Optional[sqlalchemy.engine.Connection]) -> None: - self._db_conn = db_conn - - def del_db_conn(self) -> None: - if self.conn_opened_here: - self.db_conn.close() + get_db_conn = composable_get_db_conn + set_db_conn = composable_set_db_conn + del_db_conn = composable_del_db_conn db_conn = property(get_db_conn, set_db_conn, del_db_conn) diff --git a/records_mover/db/redshift/loader.py b/records_mover/db/redshift/loader.py index 077092879..44f49c56a 100644 --- a/records_mover/db/redshift/loader.py +++ b/records_mover/db/redshift/loader.py @@ -17,6 +17,9 @@ from botocore.credentials import Credentials from ...records.delimited import complain_on_unhandled_hints from ...check_db_conn_engine import check_db_conn_engine +from ..db_conn_composable_methods import (composable_get_db_conn, + composable_set_db_conn, + composable_del_db_conn) logger = logging.getLogger(__name__) @@ -36,19 +39,9 @@ def __init__(self, self.s3_temp_base_loc = s3_temp_base_loc self.conn_opened_here = False - def get_db_conn(self) -> sqlalchemy.engine.Connection: - if self._db_conn is None: - self._db_conn = self.db_engine.connect() - self.conn_opened_here = True - logger.debug(f"Opened connection to database within {self} because none was provided.") - return self._db_conn - - def set_db_conn(self, db_conn: Optional[sqlalchemy.engine.Connection]) -> None: - self._db_conn = db_conn - - def del_db_conn(self) -> None: - if self.conn_opened_here: - self.db_conn.close() + get_db_conn = composable_get_db_conn + set_db_conn = composable_set_db_conn + del_db_conn = composable_del_db_conn db_conn = property(get_db_conn, set_db_conn, del_db_conn) diff --git a/records_mover/db/unloader.py b/records_mover/db/unloader.py index 1a6fea9aa..d1d16c272 100644 --- a/records_mover/db/unloader.py +++ b/records_mover/db/unloader.py @@ -9,6 +9,9 @@ import sqlalchemy from ..check_db_conn_engine import check_db_conn_engine import logging +from .db_conn_composable_methods import (composable_get_db_conn, + composable_set_db_conn, + composable_del_db_conn) logger = logging.getLogger(__name__) @@ -26,19 +29,9 @@ def __init__(self, self.conn_opened_here = False self.meta = MetaData() - def get_db_conn(self) -> sqlalchemy.engine.Connection: - if self._db_conn is None: - self._db_conn = self.db_engine.connect() - self.conn_opened_here = True - logger.debug("Opened connection to database which will be closed inside this uploader.") - return self._db_conn - - def set_db_conn(self, db_conn: Optional[sqlalchemy.engine.Connection]) -> None: - self._db_conn = db_conn - - def del_db_conn(self) -> None: - if self.conn_opened_here: - self.db_conn.close() + get_db_conn = composable_get_db_conn + set_db_conn = composable_set_db_conn + del_db_conn = composable_del_db_conn db_conn = property(get_db_conn, set_db_conn, del_db_conn)