diff --git a/.circleci/config.yml b/.circleci/config.yml index 8aabcfb8d..9157df5d0 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -19,7 +19,7 @@ commands: default: false python_version: type: string - description: "Version of python to test against" + description: "Version of python to test against." pandas_version: type: string description: "Version of pandas to test against, or empty string for none" @@ -34,7 +34,7 @@ commands: default: "" steps: - restore_cache: - key: deps-v9-<>-<>-<>-<>-{{ .Branch }}-{{ checksum "requirements.txt" }}-{{ checksum "setup.py" }} + key: deps-v10-<>-<>-<>-<>-{{ .Branch }}-{{ checksum "requirements.txt" }}-{{ checksum "setup.py" }} - run: name: Install python deps in venv environment: @@ -73,7 +73,7 @@ commands: fi fi - save_cache: - key: deps-v9-<>-<>-<>-<>-{{ .Branch }}-{{ checksum "requirements.txt" }}-{{ checksum "setup.py" }} + key: deps-v10-<>-<>-<>-<>-{{ .Branch }}-{{ checksum "requirements.txt" }}-{{ checksum "setup.py" }} paths: - "venv" wait_for_db: @@ -182,7 +182,7 @@ jobs: - slack/notify: event: fail branch_pattern: main - channel: engineering-general + channel: records-mover-ci-builds mentions: '<@bruno.castrokarney>' template: basic_fail_1 @@ -467,15 +467,6 @@ workflows: # We try to test against all non-end-of-life Python versions: # # https://devguide.python.org/devcycle/#end-of-life-branches - - test: - name: test-3.7 - extras: '[unittest,typecheck]' - python_version: "3.7" - pandas_version: ">=1.3.5,<2" - context: slack-secrets - filters: - tags: - only: /v\d+\.\d+\.\d+(-[\w]+)?/ - test: name: test-3.8 extras: '[unittest,typecheck]' @@ -690,8 +681,8 @@ workflows: export DB_FACTS_PATH=${PWD}/tests/integration/circleci-dbfacts.yml export RECORDS_MOVER_SESSION_TYPE=env mkdir -p test-reports/itest - cd tests/integration - python3 -m records.multi_db.test_records_table2table + cd tests/integration/records/multi_db + pytest -vvv test_records_table2table.py - integration_test_with_dbs: name: tbl2tbl-itest-old-sqlalchemy extras: '[literally_every_single_database_binary,itest]' @@ -704,8 +695,8 @@ workflows: export DB_FACTS_PATH=${PWD}/tests/integration/circleci-dbfacts.yml export RECORDS_MOVER_SESSION_TYPE=env mkdir -p test-reports/itest - cd tests/integration - python3 -m records.multi_db.test_records_table2table + cd tests/integration/records/multi_db + pytest -vvv test_records_table2table.py - integration_test_with_dbs: name: redshift-test-dataframe-schema-sql-creation extras: '[redshift-binary,itest]' @@ -746,7 +737,6 @@ workflows: - PyPI - slack-secrets requires: - - test-3.7 - test-3.8 - test-3.9 - cli-extra-test diff --git a/metrics/flake8_high_water_mark b/metrics/flake8_high_water_mark index d00491fd7..573541ac9 100644 --- a/metrics/flake8_high_water_mark +++ b/metrics/flake8_high_water_mark @@ -1 +1 @@ -1 +0 diff --git a/records_mover/airflow/hooks/records_hook.py b/records_mover/airflow/hooks/records_hook.py index 6b1b87d9a..ff858c64c 100644 --- a/records_mover/airflow/hooks/records_hook.py +++ b/records_mover/airflow/hooks/records_hook.py @@ -6,8 +6,9 @@ from records_mover.db import DBDriver from records_mover.url.resolver import UrlResolver from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook -from typing import Optional, List, TYPE_CHECKING +from typing import Optional, Union, List, TYPE_CHECKING import sqlalchemy +from ...check_db_conn_engine import check_db_conn_engine try: # Works with Airflow 1 @@ -50,11 +51,15 @@ def _url_resolver(self) -> UrlResolver: gcs_client_getter=lambda: None, 
gcp_credentials_getter=lambda: None) - def _db_driver(self, db: sqlalchemy.engine.Engine) -> DBDriver: + def _db_driver(self, db: Optional[Union[sqlalchemy.engine.Engine, + sqlalchemy.engine.Connection]] = None, + db_conn: Optional[sqlalchemy.engine.Connection] = None, + db_engine: Optional[sqlalchemy.engine.Engine] = None) -> DBDriver: s3_temp_base_loc = (self._url_resolver.directory_url(self._s3_temp_base_url) if self._s3_temp_base_url else None) - - return db_driver(db=db, url_resolver=self._url_resolver, s3_temp_base_loc=s3_temp_base_loc) + db, db_conn, db_engine = check_db_conn_engine(db=db, db_conn=db_conn, db_engine=db_engine) + return db_driver(db=db, db_conn=db_conn, db_engine=db_engine, + url_resolver=self._url_resolver, s3_temp_base_loc=s3_temp_base_loc) @property def _s3_temp_base_url(self) -> Optional[str]: diff --git a/records_mover/check_db_conn_engine.py b/records_mover/check_db_conn_engine.py new file mode 100644 index 000000000..c0b74a8be --- /dev/null +++ b/records_mover/check_db_conn_engine.py @@ -0,0 +1,31 @@ +# flake8: noqa + +import sqlalchemy +from typing import Union, Optional, Tuple +import warnings + + +def check_db_conn_engine(db: Optional[Union[sqlalchemy.engine.Engine, + sqlalchemy.engine.Connection]] = None, + db_conn: Optional[sqlalchemy.engine.Connection] = None, + db_engine: Optional[sqlalchemy.engine.Engine] = None) -> \ + Tuple[Optional[Union[sqlalchemy.engine.Engine, + sqlalchemy.engine.Connection]], + Optional[sqlalchemy.engine.Connection], + sqlalchemy.engine.Engine]: + if db: + warnings.warn("The db argument is deprecated and will be " + "removed in future releases.\n" + "Please use db_conn for Connection objects and db_engine for Engine " + "objects.", + DeprecationWarning) + if not (db or db_conn or db_engine): + raise ValueError("Either db, db_conn, or db_engine must be provided as arguments") + if isinstance(db, sqlalchemy.engine.Connection) and not db_conn: + db_conn = db + if isinstance(db, sqlalchemy.engine.Engine) and not db_engine: + db_engine = db + if not db_engine: + print("db_engine is not provided, so we're creating one from db_conn") + db_engine = db_conn.engine # type: ignore[union-attr] + return (db, db_conn, db_engine) # type: ignore[return-value] diff --git a/records_mover/db/__init__.py b/records_mover/db/__init__.py index 9ea299eb7..aca8c1638 100644 --- a/records_mover/db/__init__.py +++ b/records_mover/db/__init__.py @@ -4,6 +4,6 @@ 'create_sqlalchemy_url', ] -from .driver import DBDriver # noqa +from .driver import DBDriver # noqa from .errors import LoadError # noqa from .connect import create_sqlalchemy_url diff --git a/records_mover/db/bigquery/bigquery_db_driver.py b/records_mover/db/bigquery/bigquery_db_driver.py index 6659aa93d..851ee3b35 100644 --- a/records_mover/db/bigquery/bigquery_db_driver.py +++ b/records_mover/db/bigquery/bigquery_db_driver.py @@ -4,7 +4,7 @@ from ...records.records_format import BaseRecordsFormat, ParquetRecordsFormat, AvroRecordsFormat from ...utils.limits import INT64_MAX, INT64_MIN, FLOAT64_SIGNIFICAND_BITS, num_digits import re -from typing import Optional, Tuple +from typing import Union, Optional, Tuple from ...url.resolver import UrlResolver import sqlalchemy from .loader import BigQueryLoader @@ -19,17 +19,23 @@ class BigQueryDBDriver(DBDriver): def __init__(self, - db: sqlalchemy.engine.Engine, + db: Optional[Union[sqlalchemy.engine.Connection, sqlalchemy.engine.Engine]], url_resolver: UrlResolver, gcs_temp_base_loc: Optional[BaseDirectoryUrl] = None, + db_conn:
Optional[sqlalchemy.engine.Connection] = None, + db_engine: Optional[sqlalchemy.engine.Engine] = None, **kwargs: object) -> None: - super().__init__(db) + super().__init__(db, db_conn, db_engine) self._bigquery_loader =\ - BigQueryLoader(db=self.db, + BigQueryLoader(db=db, + db_conn=db_conn, + db_engine=db_engine, url_resolver=url_resolver, gcs_temp_base_loc=gcs_temp_base_loc) self._bigquery_unloader =\ - BigQueryUnloader(db=self.db, + BigQueryUnloader(db=db, + db_conn=db_conn, + db_engine=db_engine, url_resolver=url_resolver, gcs_temp_base_loc=gcs_temp_base_loc) diff --git a/records_mover/db/bigquery/loader.py b/records_mover/db/bigquery/loader.py index 9bee58fe2..7b4dfa4f9 100644 --- a/records_mover/db/bigquery/loader.py +++ b/records_mover/db/bigquery/loader.py @@ -1,5 +1,5 @@ from contextlib import contextmanager -from typing import List, IO, Tuple, Optional, Iterator +from typing import Union, List, IO, Tuple, Optional, Iterator from ...url import BaseDirectoryUrl from ...records.delimited import complain_on_unhandled_hints import pprint @@ -19,19 +19,41 @@ import logging from ..loader import LoaderFromFileobj from ..errors import NoTemporaryBucketConfiguration +from ...check_db_conn_engine import check_db_conn_engine logger = logging.getLogger(__name__) class BigQueryLoader(LoaderFromFileobj): def __init__(self, - db: sqlalchemy.engine.Engine, + db: Optional[Union[sqlalchemy.engine.Connection, sqlalchemy.engine.Engine]], url_resolver: UrlResolver, - gcs_temp_base_loc: Optional[BaseDirectoryUrl])\ - -> None: + gcs_temp_base_loc: Optional[BaseDirectoryUrl], + db_conn: Optional[sqlalchemy.engine.Connection] = None, + db_engine: Optional[sqlalchemy.engine.Engine] = None) -> None: + db, db_conn, db_engine = check_db_conn_engine(db=db, db_conn=db_conn, db_engine=db_engine) self.db = db + self._db_conn = db_conn + self.db_engine = db_engine self.url_resolver = url_resolver self.gcs_temp_base_loc = gcs_temp_base_loc + self.conn_opened_here = False + + def get_db_conn(self) -> sqlalchemy.engine.Connection: + if self._db_conn is None: + self._db_conn = self.db_engine.connect() + self.conn_opened_here = True + logger.debug(f"Opened connection to database within {self} because none was provided.") + return self._db_conn + + def set_db_conn(self, db_conn: Optional[sqlalchemy.engine.Connection]) -> None: + self._db_conn = db_conn + + def del_db_conn(self) -> None: + if self.conn_opened_here: + self.db_conn.close() + + db_conn = property(get_db_conn, set_db_conn, del_db_conn) def best_scheme_to_load_from(self) -> str: return 'gs' @@ -77,7 +99,7 @@ def load_from_fileobj(self, schema: str, table: str, logger.info("Loading from fileobj into BigQuery") # https://googleapis.github.io/google-cloud-python/latest/bigquery/usage/tables.html#creating-a-table connection: Connection =\ - self.db.engine.raw_connection().connection + self.db_engine.raw_connection().connection # https://google-cloud.readthedocs.io/en/latest/bigquery/generated/google.cloud.bigquery.client.Client.html client: Client = connection._client project_id, dataset_id = self._parse_bigquery_schema_name(schema) @@ -122,7 +144,7 @@ def load(self, logger.info("Loading from records directory into BigQuery") # https://googleapis.github.io/google-cloud-python/latest/bigquery/usage/tables.html#creating-a-table - connection: Connection = self.db.engine.raw_connection().connection + connection: Connection = self.db_engine.raw_connection().connection # 
https://google-cloud.readthedocs.io/en/latest/bigquery/generated/google.cloud.bigquery.client.Client.html client: Client = connection._client project_id, dataset_id = self._parse_bigquery_schema_name(schema) @@ -187,3 +209,6 @@ def known_supported_records_formats_for_load(self) -> List[BaseRecordsFormat]: ParquetRecordsFormat(), AvroRecordsFormat() ] + + def __del__(self) -> None: + self.del_db_conn() diff --git a/records_mover/db/bigquery/unloader.py b/records_mover/db/bigquery/unloader.py index 6ec2d3201..7d8c3ba79 100644 --- a/records_mover/db/bigquery/unloader.py +++ b/records_mover/db/bigquery/unloader.py @@ -1,6 +1,6 @@ import sqlalchemy from contextlib import contextmanager -from typing import List, Iterator, Optional, Tuple +from typing import List, Iterator, Optional, Union, Tuple import logging from google.cloud.bigquery.dbapi.connection import Connection from google.cloud.bigquery.client import Client @@ -12,20 +12,22 @@ from records_mover.records.unload_plan import RecordsUnloadPlan from records_mover.records.records_directory import RecordsDirectory from records_mover.db.errors import NoTemporaryBucketConfiguration +from ...check_db_conn_engine import check_db_conn_engine logger = logging.getLogger(__name__) class BigQueryUnloader(Unloader): def __init__(self, - db: sqlalchemy.engine.Engine, + db: Optional[Union[sqlalchemy.engine.Connection, sqlalchemy.engine.Engine]], url_resolver: UrlResolver, - gcs_temp_base_loc: Optional[BaseDirectoryUrl])\ - -> None: - self.db = db + gcs_temp_base_loc: Optional[BaseDirectoryUrl], + db_conn: Optional[sqlalchemy.engine.Connection] = None, + db_engine: Optional[sqlalchemy.engine.Engine] = None) -> None: + db, db_conn, db_engine = check_db_conn_engine(db=db, db_conn=db_conn, db_engine=db_engine) self.url_resolver = url_resolver self.gcs_temp_base_loc = gcs_temp_base_loc - super().__init__(db=db) + super().__init__(db=db, db_conn=db_conn, db_engine=db_engine) def can_unload_format(self, target_records_format: BaseRecordsFormat) -> bool: if isinstance(target_records_format, AvroRecordsFormat): @@ -93,7 +95,7 @@ def unload(self, logger.info("Loading from records directory into BigQuery") # https://googleapis.github.io/google-cloud-python/latest/bigquery/usage/tables.html#creating-a-table connection: Connection =\ - self.db.engine.raw_connection().connection + self.db_engine.raw_connection().connection # https://google-cloud.readthedocs.io/en/latest/bigquery/generated/google.cloud.bigquery.client.Client.html client: Client = connection._client project_id, dataset_id = self._parse_bigquery_schema_name(schema) diff --git a/records_mover/db/driver.py b/records_mover/db/driver.py index 7d7edfac9..ae517ccef 100644 --- a/records_mover/db/driver.py +++ b/records_mover/db/driver.py @@ -1,3 +1,4 @@ +from ..check_db_conn_engine import check_db_conn_engine from sqlalchemy.schema import CreateTable from ..records.records_format import BaseRecordsFormat from .loader import LoaderFromFileobj, LoaderFromRecordsDirectory @@ -9,7 +10,7 @@ from records_mover.db.quoting import quote_group_name, quote_user_name, quote_schema_and_table from abc import ABCMeta, abstractmethod from records_mover.records import RecordsSchema -from typing import Dict, List, Tuple, Optional, TYPE_CHECKING +from typing import Union, Dict, List, Tuple, Optional, TYPE_CHECKING if TYPE_CHECKING: from typing_extensions import Literal # noqa @@ -18,13 +19,36 @@ class DBDriver(metaclass=ABCMeta): def __init__(self, - db: sqlalchemy.engine.Engine, **kwargs) -> None: + db: 
Optional[Union[sqlalchemy.engine.Engine, + sqlalchemy.engine.Connection]], + db_conn: Optional[sqlalchemy.engine.Connection] = None, + db_engine: Optional[sqlalchemy.engine.Engine] = None, + **kwargs) -> None: + db, db_conn, db_engine = check_db_conn_engine(db=db, db_conn=db_conn, db_engine=db_engine) self.db = db - self.db_engine = db.engine + self.db_engine = db_engine + self._db_conn = db_conn + self.conn_opened_here = False self.meta = MetaData() + def get_db_conn(self) -> sqlalchemy.engine.Connection: + if self._db_conn is None: + self._db_conn = self.db_engine.connect() + self.conn_opened_here = True + logger.debug(f"Opened connection to database within {self} because none was provided.") + return self._db_conn + + def set_db_conn(self, db_conn: Optional[sqlalchemy.engine.Connection]) -> None: + self._db_conn = db_conn + + def del_db_conn(self) -> None: + if self.conn_opened_here: + self.db_conn.close() + + db_conn = property(get_db_conn, set_db_conn, del_db_conn) + def has_table(self, schema: str, table: str) -> bool: - return self.db.dialect.has_table(self.db, table_name=table, schema=schema) + return sqlalchemy.inspect(self.db_engine).has_table(table, schema=schema) def table(self, schema: str, @@ -49,33 +73,55 @@ def varchar_length_is_in_chars(self) -> bool: override it to control""" return False - def set_grant_permissions_for_groups(self, schema_name: str, table: str, + def set_grant_permissions_for_groups(self, + schema_name: str, + table: str, groups: Dict[str, List[str]], - db: sqlalchemy.engine.Engine) -> None: - schema_and_table: str = quote_schema_and_table(self.db.engine, schema_name, table) + db: Optional[Union[sqlalchemy.engine.Engine, + sqlalchemy.engine.Connection]], + db_conn: Optional[sqlalchemy.engine.Connection] = None, + db_engine: Optional[sqlalchemy.engine.Engine] = None + ) -> None: + db, db_conn, db_engine = check_db_conn_engine(db=db, db_conn=db_conn, db_engine=db_engine) + schema_and_table: str = quote_schema_and_table(None, schema_name, table, + db_engine=self.db_engine) for perm_type in groups: groups_list = groups[perm_type] for group in groups_list: - group_name: str = quote_group_name(self.db.engine, group) + group_name: str = quote_group_name(None, group, db_engine=self.db_engine) if not perm_type.isalpha(): raise TypeError("Please make sure your permission types" " are an acceptable value.") perms_sql = f'GRANT {perm_type} ON TABLE {schema_and_table} TO {group_name}' - db.execute(perms_sql) + if db_conn: + db_conn.execute(perms_sql) + else: + with db_engine.connect() as conn: + conn.execute(perms_sql) def set_grant_permissions_for_users(self, schema_name: str, table: str, users: Dict[str, List[str]], - db: sqlalchemy.engine.Engine) -> None: - schema_and_table: str = quote_schema_and_table(self.db.engine, schema_name, table) + db: Optional[Union[sqlalchemy.engine.Engine, + sqlalchemy.engine.Connection]], + db_conn: Optional[sqlalchemy.engine.Connection] = None, + db_engine: Optional[sqlalchemy.engine.Engine] = None + ) -> None: + db, db_conn, db_engine = check_db_conn_engine(db=db, db_conn=db_conn, db_engine=db_engine) + schema_and_table: str = quote_schema_and_table(None, schema_name, table, + db_engine=self.db_engine) for perm_type in users: user_list = users[perm_type] for user in user_list: - user_name: str = quote_user_name(self.db.engine, user) + user_name: str = quote_user_name(self.db_engine, user) if not perm_type.isalpha(): raise TypeError("Please make sure your permission types" " are an acceptable value.") perms_sql = f'GRANT {perm_type} 
ON TABLE {schema_and_table} TO {user_name}' - db.execute(perms_sql) + if db_conn: + db_conn.execute(perms_sql) + else: + with db_engine.connect() as conn: + conn.execute(perms_sql) def supports_time_type(self) -> bool: return True @@ -190,6 +236,9 @@ def tweak_records_schema_after_unload(self, records_format: BaseRecordsFormat) -> RecordsSchema: return records_schema + def __del__(self) -> None: + self.del_db_conn() + class GenericDBDriver(DBDriver): def loader_from_fileobj(self) -> None: diff --git a/records_mover/db/factory.py b/records_mover/db/factory.py index eba30dbba..48cf6b2d2 100644 --- a/records_mover/db/factory.py +++ b/records_mover/db/factory.py @@ -1,31 +1,38 @@ +# flake8: noqa + from .driver import GenericDBDriver, DBDriver import sqlalchemy +from typing import Union, Optional +from ..check_db_conn_engine import check_db_conn_engine -def db_driver(db: sqlalchemy.engine.Engine, +def db_driver(db: Optional[Union[sqlalchemy.engine.Engine, + sqlalchemy.engine.Connection]], + db_conn: Optional[sqlalchemy.engine.Connection] = None, + db_engine: Optional[sqlalchemy.engine.Engine] = None, **kwargs) -> DBDriver: - engine: sqlalchemy.engine.Engine = db.engine - engine_name: str = engine.name + db, db_conn, db_engine = check_db_conn_engine(db=db, db_conn=db_conn, db_engine=db_engine) + engine_name: str = db_engine.name if engine_name == 'vertica': from .vertica.vertica_db_driver import VerticaDBDriver - return VerticaDBDriver(db, **kwargs) + return VerticaDBDriver(db=db, db_conn=db_conn, db_engine=db_engine, **kwargs) elif engine_name == 'redshift': from .redshift.redshift_db_driver import RedshiftDBDriver - return RedshiftDBDriver(db, **kwargs) + return RedshiftDBDriver(db=db, db_conn=db_conn, db_engine=db_engine, **kwargs) elif engine_name == 'bigquery': from .bigquery.bigquery_db_driver import BigQueryDBDriver - return BigQueryDBDriver(db, **kwargs) + return BigQueryDBDriver(db=db, db_conn=db_conn, db_engine=db_engine, **kwargs) elif engine_name == 'postgresql': from .postgres.postgres_db_driver import PostgresDBDriver - return PostgresDBDriver(db, **kwargs) + return PostgresDBDriver(db=db, db_conn=db_conn, db_engine=db_engine, **kwargs) elif engine_name == 'mysql': from .mysql.mysql_db_driver import MySQLDBDriver - return MySQLDBDriver(db, **kwargs) + return MySQLDBDriver(db=db, db_conn=db_conn, db_engine=db_engine, **kwargs) else: - return GenericDBDriver(db, **kwargs) + return GenericDBDriver(db=db, db_conn=db_conn, db_engine=db_engine, **kwargs) diff --git a/records_mover/db/mysql/loader.py b/records_mover/db/mysql/loader.py index d7cff19a8..34f1231c6 100644 --- a/records_mover/db/mysql/loader.py +++ b/records_mover/db/mysql/loader.py @@ -9,20 +9,43 @@ from .load_options import mysql_load_options from ...records.delimited import complain_on_unhandled_hints from ...url.resolver import UrlResolver -from typing import List +from typing import Union, List, Optional import logging import tempfile +from ...check_db_conn_engine import check_db_conn_engine logger = logging.getLogger(__name__) class MySQLLoader(LoaderFromRecordsDirectory): def __init__(self, - db: sqlalchemy.engine.Engine, - url_resolver: UrlResolver) -> None: + db: Optional[Union[sqlalchemy.engine.Engine, sqlalchemy.engine.Connection]], + url_resolver: UrlResolver, + db_conn: Optional[sqlalchemy.engine.Connection] = None, + db_engine: Optional[sqlalchemy.engine.Engine] = None) -> None: + db, db_conn, db_engine = check_db_conn_engine(db=db, db_conn=db_conn, db_engine=db_engine) + self.conn_opened_here = False self.db = 
db + self._db_conn = db_conn + self.db_engine = db_engine self.url_resolver = url_resolver + def get_db_conn(self) -> sqlalchemy.engine.Connection: + if self._db_conn is None: + self._db_conn = self.db_engine.connect() + self.conn_opened_here = True + logger.debug(f"Opened connection to database within {self} because none was provided.") + return self._db_conn + + def set_db_conn(self, db_conn: Optional[sqlalchemy.engine.Connection]) -> None: + self._db_conn = db_conn + + def del_db_conn(self) -> None: + if self.conn_opened_here: + self.db_conn.close() + + db_conn = property(get_db_conn, set_db_conn, del_db_conn) + def load(self, schema: str, table: str, @@ -61,7 +84,7 @@ def load(self, schema_name=schema) logger.info(f"Loading to MySQL with options: {load_options}") logger.info(str(sql)) - self.db.execute(sql) + self.db_conn.execute(sql) logger.info("MySQL LOAD DATA complete.") return None @@ -101,3 +124,6 @@ def known_supported_records_formats_for_load(self) -> List[BaseRecordsFormat]: 'compression': None }), ] + + def __del__(self) -> None: + self.del_db_conn() diff --git a/records_mover/db/mysql/mysql_db_driver.py b/records_mover/db/mysql/mysql_db_driver.py index 66ae14dae..ec8967f7b 100644 --- a/records_mover/db/mysql/mysql_db_driver.py +++ b/records_mover/db/mysql/mysql_db_driver.py @@ -7,7 +7,7 @@ num_digits) from ..driver import DBDriver from .loader import MySQLLoader -from typing import Optional, Tuple +from typing import Optional, Tuple, Union from ..loader import LoaderFromFileobj, LoaderFromRecordsDirectory from ...url.resolver import UrlResolver @@ -17,11 +17,16 @@ class MySQLDBDriver(DBDriver): def __init__(self, - db: sqlalchemy.engine.Engine, url_resolver: UrlResolver, + db: Optional[Union[sqlalchemy.engine.Engine, sqlalchemy.engine.Connection]] = None, + db_conn: Optional[sqlalchemy.engine.Connection] = None, + db_engine: Optional[sqlalchemy.engine.Engine] = None, **kwargs) -> None: - super().__init__(db) + self.conn_opened_here = False + super().__init__(db=db, db_conn=db_conn, db_engine=db_engine) self._mysql_loader = MySQLLoader(db=db, + db_conn=db_conn, + db_engine=db_engine, url_resolver=url_resolver) def loader(self) -> Optional[LoaderFromRecordsDirectory]: diff --git a/records_mover/db/postgres/loader.py b/records_mover/db/postgres/loader.py index 180eff63c..a7b8f0d65 100644 --- a/records_mover/db/postgres/loader.py +++ b/records_mover/db/postgres/loader.py @@ -10,9 +10,10 @@ from ...records.processing_instructions import ProcessingInstructions from .sqlalchemy_postgres_copy import copy_from from .copy_options import postgres_copy_from_options -from typing import IO, List, Iterable +from typing import IO, Union, List, Iterable, Optional from ..loader import LoaderFromFileobj import logging +from ...check_db_conn_engine import check_db_conn_engine logger = logging.getLogger(__name__) @@ -21,10 +22,32 @@ class PostgresLoader(LoaderFromFileobj): def __init__(self, url_resolver: UrlResolver, meta: MetaData, - db: sqlalchemy.engine.Engine) -> None: + db: Optional[Union[sqlalchemy.engine.Connection, sqlalchemy.engine.Engine]], + db_conn: Optional[sqlalchemy.engine.Connection] = None, + db_engine: Optional[sqlalchemy.engine.Engine] = None) -> None: + db, db_conn, db_engine = check_db_conn_engine(db=db, db_conn=db_conn, db_engine=db_engine) self.url_resolver = url_resolver self.db = db + self._db_conn = db_conn + self.db_engine = db_engine self.meta = meta + self.conn_opened_here = False + + def get_db_conn(self) -> sqlalchemy.engine.Connection: + if self._db_conn is None: 
+ self._db_conn = self.db_engine.connect() + self.conn_opened_here = True + logger.debug(f"Opened connection to database within {self} because none was provided.") + return self._db_conn + + def set_db_conn(self, db_conn: Optional[sqlalchemy.engine.Connection]) -> None: + self._db_conn = db_conn + + def del_db_conn(self) -> None: + if self.conn_opened_here: + self.db_conn.close() + + db_conn = property(get_db_conn, set_db_conn, del_db_conn) def load_from_fileobj(self, schema: str, @@ -59,31 +82,30 @@ def load_from_fileobjs(self, table_obj = Table(table, self.meta, schema=schema, - autoload_with=self.db) + autoload_with=self.db_engine) - with self.db.engine.begin() as conn: - # https://www.postgresql.org/docs/8.3/sql-set.html - # - # The effects of SET LOCAL last only till the end of the - # current transaction, whether committed or not. A special - # case is SET followed by SET LOCAL within a single - # transaction: the SET LOCAL value will be seen until the end - # of the transaction, but afterwards (if the transaction is - # committed) the SET value will take effect. - date_style = f"ISO, {date_order_style}" - sql = f"SET LOCAL DateStyle = {quote_value(conn, date_style)}" - logger.info(sql) - conn.execute(text(sql)) - - for fileobj in fileobjs: - # Postgres COPY FROM defaults to appending data--we - # let the records Prep class decide what to do about - # the existing table, so it's safe to call this - # multiple times and append until done: - copy_from(fileobj, - table_obj, - conn, - **postgres_options) + # https://www.postgresql.org/docs/8.3/sql-set.html + # + # The effects of SET LOCAL last only till the end of the + # current transaction, whether committed or not. A special + # case is SET followed by SET LOCAL within a single + # transaction: the SET LOCAL value will be seen until the end + # of the transaction, but afterwards (if the transaction is + # committed) the SET value will take effect. 
+ date_style = f"ISO, {date_order_style}" + sql = f"SET LOCAL DateStyle = {quote_value(None, date_style, db_engine=self.db_engine)}" + logger.info(sql) + self.db_conn.execute(text(sql)) + + for fileobj in fileobjs: + # Postgres COPY FROM defaults to appending data--we + # let the records Prep class decide what to do about + # the existing table, so it's safe to call this + # multiple times and append until done: + copy_from(fileobj, + table_obj, + self.db_conn, + **postgres_options) logger.info('Copy complete') def can_load_this_format(self, source_records_format: BaseRecordsFormat) -> bool: @@ -130,3 +152,6 @@ def known_supported_records_formats_for_load(self) -> List[BaseRecordsFormat]: DelimitedRecordsFormat(variant='bigquery', hints={'compression': None}), ] + + def __del__(self) -> None: + self.del_db_conn() diff --git a/records_mover/db/postgres/postgres_db_driver.py b/records_mover/db/postgres/postgres_db_driver.py index 6f8975803..dbc4e8c68 100644 --- a/records_mover/db/postgres/postgres_db_driver.py +++ b/records_mover/db/postgres/postgres_db_driver.py @@ -12,7 +12,7 @@ from ..loader import LoaderFromFileobj, LoaderFromRecordsDirectory from .unloader import PostgresUnloader from ..unloader import Unloader -from typing import Optional, Tuple +from typing import Optional, Tuple, Union logger = logging.getLogger(__name__) @@ -20,14 +20,20 @@ class PostgresDBDriver(DBDriver): def __init__(self, - db: sqlalchemy.engine.Engine, + db: Optional[Union[sqlalchemy.engine.Engine, sqlalchemy.engine.Connection]], url_resolver: UrlResolver, + db_conn: Optional[sqlalchemy.engine.Connection] = None, + db_engine: Optional[sqlalchemy.engine.Engine] = None, **kwargs) -> None: - super().__init__(db) + super().__init__(db=db, db_conn=db_conn, db_engine=db_engine) self._postgres_loader = PostgresLoader(url_resolver=url_resolver, meta=self.meta, - db=self.db) - self._postgres_unloader = PostgresUnloader(db=self.db) + db=db, + db_conn=db_conn, + db_engine=db_engine) + self._postgres_unloader = PostgresUnloader(db=db, + db_engine=db_engine, + db_conn=db_conn) def loader(self) -> Optional[LoaderFromRecordsDirectory]: return self._postgres_loader diff --git a/records_mover/db/postgres/sqlalchemy_postgres_copy.py b/records_mover/db/postgres/sqlalchemy_postgres_copy.py index 46c4a44be..a83e90b2d 100644 --- a/records_mover/db/postgres/sqlalchemy_postgres_copy.py +++ b/records_mover/db/postgres/sqlalchemy_postgres_copy.py @@ -15,7 +15,7 @@ __version__ = '0.5.0' -def copy_to(source, dest, conn, **flags): +def copy_to(source, dest, engine_or_conn, **flags): """Export a query or select to a file. For flags, see the PostgreSQL documentation at http://www.postgresql.org/docs/9.5/static/sql-copy.html. @@ -24,29 +24,32 @@ def copy_to(source, dest, conn, **flags): with open('/path/to/file.tsv', 'w') as fp: copy_to(select, fp, conn) + query = session.query(MyModel) + with open('/path/to/file.csv', 'w') as fp: + copy_to(query, fp, engine, format='csv', null='.') :param source: SQLAlchemy query or select :param dest: Destination file pointer, in write mode - :param conn: SQLAlchemy connection + :param engine_or_conn: SQLAlchemy engine, connection, or raw_connection :param **flags: Options passed through to COPY - If an existing connection is passed to `conn`, it is the caller's + If an existing connection is passed to `engine_or_conn`, it is the caller's responsibility to commit and close.
""" dialect = postgresql.dialect() statement = getattr(source, 'statement', source) compiled = statement.compile(dialect=dialect) - raw_conn, autoclose = raw_connection_from(conn) - cursor = raw_conn.cursor() + conn, autoclose = raw_connection_from(engine_or_conn) + cursor = conn.cursor() query = cursor.mogrify(compiled.string, compiled.params).decode() formatted_flags = '({})'.format(format_flags(flags)) if flags else '' copy = 'COPY ({}) TO STDOUT {}'.format(query, formatted_flags) cursor.copy_expert(copy, dest) if autoclose: - raw_conn.close() + conn.close() -def copy_from(source, dest, conn, columns=(), **flags): +def copy_from(source, dest, engine_or_conn, columns=(), **flags): """Import a table from a file. For flags, see the PostgreSQL documentation at http://www.postgresql.org/docs/9.5/static/sql-copy.html. @@ -54,13 +57,16 @@ def copy_from(source, dest, conn, columns=(), **flags): with open('/path/to/file.tsv') as fp: copy_from(fp, MyTable, conn) + with open('/path/to/file.csv') as fp: + copy_from(fp, MyModel, engine, format='csv') + :param source: Source file pointer, in read mode :param dest: SQLAlchemy model or table - :param conn: SQLAlchemy connection + :param engine_or_conn: SQLAlchemy engine, connection, or raw_connection :param columns: Optional tuple of columns :param **flags: Options passed through to COPY - If an existing connection is passed to `conn`, it is the caller's + If an existing connection is passed to `engine_or_conn`, it is the caller's responsibility to commit and close. The `columns` flag can be set to a tuple of strings to specify the column @@ -68,8 +74,8 @@ def copy_from(source, dest, conn, columns=(), **flags): postgres to ignore the first line of `source`. """ tbl = dest.__table__ if is_model(dest) else dest - raw_conn, autoclose = raw_connection_from(conn) - cursor = raw_conn.cursor() + conn, autoclose = raw_connection_from(engine_or_conn) + cursor = conn.cursor() relation = '.'.join('"{}"'.format(part) for part in (tbl.schema, tbl.name) if part) formatted_columns = '({})'.format(','.join(columns)) if columns else '' formatted_flags = '({})'.format(format_flags(flags)) if flags else '' @@ -80,20 +86,20 @@ def copy_from(source, dest, conn, columns=(), **flags): ) cursor.copy_expert(copy, source) if autoclose: - raw_conn.commit() - raw_conn.close() + conn.commit() + conn.close() -def raw_connection_from(conn): +def raw_connection_from(engine_or_conn): """Extract a raw_connection and determine if it should be automatically closed. Only connections opened by this package will be closed automatically. 
""" - if hasattr(conn, 'cursor'): - return conn, False - if hasattr(conn, 'connection'): - return conn.connection, False - return conn.raw_connection(), True + if hasattr(engine_or_conn, 'cursor'): + return engine_or_conn, False + if hasattr(engine_or_conn, 'connection'): + return engine_or_conn.connection, False + return engine_or_conn.raw_connection(), True def format_flags(flags): diff --git a/records_mover/db/postgres/sqlalchemy_postgres_copy.pyi b/records_mover/db/postgres/sqlalchemy_postgres_copy.pyi index e9087a08b..cf232fb1a 100644 --- a/records_mover/db/postgres/sqlalchemy_postgres_copy.pyi +++ b/records_mover/db/postgres/sqlalchemy_postgres_copy.pyi @@ -1,10 +1,11 @@ -from typing import Union, IO +from typing import Union, IO, Optional import sqlalchemy def copy_from(source: IO[bytes], dest: sqlalchemy.schema.Table, - conn: sqlalchemy.engine.Connection, + engine_or_conn: Optional[Union[sqlalchemy.engine.Engine, + sqlalchemy.engine.Connection]], **flags: object) -> None: ... @@ -12,6 +13,6 @@ def copy_from(source: IO[bytes], def copy_to(source: Union[sqlalchemy.sql.expression.Select, sqlalchemy.orm.query.Query], dest: IO[bytes], - conn: sqlalchemy.engine.Connection, + engine_or_conn: Optional[Union[sqlalchemy.engine.Engine, sqlalchemy.engine.Connection]], **flags: object) -> None: ... diff --git a/records_mover/db/postgres/unloader.py b/records_mover/db/postgres/unloader.py index 07106e132..4392935aa 100644 --- a/records_mover/db/postgres/unloader.py +++ b/records_mover/db/postgres/unloader.py @@ -44,29 +44,28 @@ def unload(self, table_obj = Table(table, self.meta, schema=schema, - autoload_with=self.db) + autoload_with=self.db_engine) - with self.db.engine.begin() as conn: - # https://www.postgresql.org/docs/8.3/sql-set.html - # - # The effects of SET LOCAL last only till the end of the - # current transaction, whether committed or not. A special - # case is SET followed by SET LOCAL within a single - # transaction: the SET LOCAL value will be seen until the end - # of the transaction, but afterwards (if the transaction is - # committed) the SET value will take effect. - date_style = f"{date_output_style}, {date_order_style}" - sql = f"SET LOCAL DateStyle = {quote_value(conn, date_style)}" - logger.info(sql) - conn.execute(text(sql)) + # https://www.postgresql.org/docs/8.3/sql-set.html + # + # The effects of SET LOCAL last only till the end of the + # current transaction, whether committed or not. A special + # case is SET followed by SET LOCAL within a single + # transaction: the SET LOCAL value will be seen until the end + # of the transaction, but afterwards (if the transaction is + # committed) the SET value will take effect. 
+ date_style = f"{date_output_style}, {date_order_style}" + sql = f"SET LOCAL DateStyle = {quote_value(None, date_style, db_engine=self.db_engine)}" + logger.info(sql) + self.db_conn.execute(text(sql)) - filename = unload_plan.records_format.generate_filename('data') - loc = directory.loc.file_in_this_directory(filename) - with loc.open(mode='wb') as fileobj: - copy_to(table_obj.select(), - fileobj, - conn, - **postgres_options) + filename = unload_plan.records_format.generate_filename('data') + loc = directory.loc.file_in_this_directory(filename) + with loc.open(mode='wb') as fileobj: + copy_to(table_obj.select(), + fileobj, + self.db_conn, + **postgres_options) logger.info('Copy complete') directory.save_preliminary_manifest() diff --git a/records_mover/db/quoting.py b/records_mover/db/quoting.py index 107aba211..b651ff6af 100644 --- a/records_mover/db/quoting.py +++ b/records_mover/db/quoting.py @@ -1,9 +1,12 @@ -from sqlalchemy.engine import Engine +from sqlalchemy.engine import Engine, Connection +from typing import Union, Optional +from ..check_db_conn_engine import check_db_conn_engine -def quote_schema_and_table(db: Engine, +def quote_schema_and_table(db: Optional[Union[Connection, Engine]], schema: str, - table: str) -> str: + table: str, + db_engine: Optional[Engine] = None) -> str: """ Prevent SQL injection when we're not able to use bind variables (e.g., passing a table name in SQL). @@ -17,22 +20,32 @@ def quote_schema_and_table(db: Engine, quoted_bobby: "Robert'); DROP TABLE Students;--" """ - dialect = db.dialect + db, _, db_engine = check_db_conn_engine(db=db, db_conn=None, db_engine=db_engine) + dialect = db_engine.dialect return (dialect.preparer(dialect).quote(schema) + '.' + dialect.preparer(dialect).quote(table)) -def quote_table_only(db: Engine, table: str) -> str: - dialect = db.dialect +def quote_table_only(db: Optional[Union[Connection, Engine]], + table: str, + db_engine: Optional[Engine] = None) -> str: + db, _, db_engine = check_db_conn_engine(db=db, db_conn=None, db_engine=db_engine) + + dialect = db_engine.dialect return dialect.preparer(dialect).quote(table) -def quote_column_name(db: Engine, column_name: str) -> str: - dialect = db.dialect +def quote_column_name(db: Optional[Union[Connection, Engine]], + column_name: str, + db_engine: Optional[Engine] = None) -> str: + db, _, db_engine = check_db_conn_engine(db=db, db_conn=None, db_engine=db_engine) + dialect = db_engine.dialect return dialect.preparer(dialect).quote(column_name) -def quote_value(db: Engine, value: str) -> str: +def quote_value(db: Optional[Union[Connection, Engine]], + value: str, + db_engine: Optional[Engine] = None) -> str: """ Prevent SQL injection on literal string values in places when we're not able to use bind variables (e.g., using weird DB-specific @@ -47,15 +60,22 @@ def quote_value(db: Engine, value: str) -> str: quoted_bobby: 'Robert''); DROP TABLE Students;--' """ - dialect = db.dialect + db, _, db_engine = check_db_conn_engine(db=db, db_conn=None, db_engine=db_engine) + dialect = db_engine.dialect return dialect.preparer(dialect, initial_quote="'").quote(value) -def quote_user_name(db: Engine, user_name: str) -> str: - dialect = db.dialect +def quote_user_name(db: Optional[Union[Connection, Engine]], + user_name: str, + db_engine: Optional[Engine] = None) -> str: + db, _, db_engine = check_db_conn_engine(db=db, db_conn=None, db_engine=db_engine) + dialect = db_engine.dialect return dialect.preparer(dialect).quote_identifier(user_name) -def quote_group_name(db: Engine, 
group_name: str) -> str: - dialect = db.dialect +def quote_group_name(db: Optional[Union[Connection, Engine]], + group_name: str, + db_engine: Optional[Engine] = None) -> str: + db, _, db_engine = check_db_conn_engine(db=db, db_conn=None, db_engine=db_engine) + dialect = db_engine.dialect return dialect.preparer(dialect).quote_identifier(group_name) diff --git a/records_mover/db/redshift/loader.py b/records_mover/db/redshift/loader.py index 5b6d4a226..077092879 100644 --- a/records_mover/db/redshift/loader.py +++ b/records_mover/db/redshift/loader.py @@ -12,23 +12,45 @@ from .records_copy import redshift_copy_options from ...records.load_plan import RecordsLoadPlan from ..errors import CredsDoNotSupportS3Import, NoTemporaryBucketConfiguration -from typing import Optional, List, Iterator +from typing import Optional, Union, List, Iterator from ...url import BaseDirectoryUrl from botocore.credentials import Credentials from ...records.delimited import complain_on_unhandled_hints +from ...check_db_conn_engine import check_db_conn_engine logger = logging.getLogger(__name__) class RedshiftLoader(LoaderFromRecordsDirectory): def __init__(self, - db: sqlalchemy.engine.Engine, meta: sqlalchemy.MetaData, - s3_temp_base_loc: Optional[BaseDirectoryUrl])\ - -> None: + s3_temp_base_loc: Optional[BaseDirectoryUrl], + db: Optional[Union[sqlalchemy.engine.Engine, sqlalchemy.engine.Connection]], + db_conn: Optional[sqlalchemy.engine.Connection] = None, + db_engine: Optional[sqlalchemy.engine.Engine] = None) -> None: + db, db_conn, db_engine = check_db_conn_engine(db=db, db_conn=db_conn, db_engine=db_engine) self.db = db + self._db_conn = db_conn + self.db_engine = db_engine self.meta = meta self.s3_temp_base_loc = s3_temp_base_loc + self.conn_opened_here = False + + def get_db_conn(self) -> sqlalchemy.engine.Connection: + if self._db_conn is None: + self._db_conn = self.db_engine.connect() + self.conn_opened_here = True + logger.debug(f"Opened connection to database within {self} because none was provided.") + return self._db_conn + + def set_db_conn(self, db_conn: Optional[sqlalchemy.engine.Connection]) -> None: + self._db_conn = db_conn + + def del_db_conn(self) -> None: + if self.conn_opened_here: + self.db_conn.close() + + db_conn = property(get_db_conn, set_db_conn, del_db_conn) @contextmanager def temporary_s3_directory_loc(self) -> Iterator[BaseDirectoryUrl]: @@ -92,9 +114,9 @@ def load(self, empty_as_null=True, **redshift_options) # type: ignore logger.info(f"Starting Redshift COPY from {directory}...") - redshift_pid: int = self.db.execute(text("SELECT pg_backend_pid();")).scalar() + redshift_pid: int = self.db_conn.execute(text("SELECT pg_backend_pid();")).scalar() try: - self.db.execute(copy) + self.db_conn.execute(copy) except sqlalchemy.exc.InternalError: # Upon a load erorr, we receive: # @@ -174,3 +196,6 @@ def temporary_loadable_directory_loc(self) -> Iterator[BaseDirectoryUrl]: def has_temporary_loadable_directory_loc(self) -> bool: return self.s3_temp_base_loc is not None + + def __del__(self) -> None: + self.del_db_conn() diff --git a/records_mover/db/redshift/redshift_db_driver.py b/records_mover/db/redshift/redshift_db_driver.py index 98014232e..3846237fa 100644 --- a/records_mover/db/redshift/redshift_db_driver.py +++ b/records_mover/db/redshift/redshift_db_driver.py @@ -12,13 +12,14 @@ num_digits) from .sql import schema_sql_from_admin_views import timeout_decorator -from typing import Optional, Dict, List, Tuple +from typing import Optional, Union, Dict, List, Tuple from ...url.base 
import BaseDirectoryUrl from records_mover.db.quoting import quote_group_name, quote_schema_and_table from .unloader import RedshiftUnloader from ..unloader import Unloader from .loader import RedshiftLoader from ..loader import LoaderFromRecordsDirectory +from ...check_db_conn_engine import check_db_conn_engine logger = logging.getLogger(__name__) @@ -26,22 +27,28 @@ class RedshiftDBDriver(DBDriver): def __init__(self, - db: sqlalchemy.engine.Engine, + db: Optional[Union[sqlalchemy.engine.Engine, sqlalchemy.engine.Connection]], + db_conn: Optional[sqlalchemy.engine.Connection] = None, + db_engine: Optional[sqlalchemy.engine.Engine] = None, s3_temp_base_loc: Optional[BaseDirectoryUrl] = None, **kwargs) -> None: - super().__init__(db) + super().__init__(db=db, db_conn=db_conn, db_engine=db_engine) self.s3_temp_base_loc = s3_temp_base_loc self._redshift_loader =\ RedshiftLoader(db=db, + db_conn=db_conn, + db_engine=db_engine, meta=self.meta, s3_temp_base_loc=s3_temp_base_loc) self._redshift_unloader =\ RedshiftUnloader(db=db, + db_conn=db_conn, + db_engine=db_engine, table=self.table, s3_temp_base_loc=s3_temp_base_loc) def schema_sql(self, schema: str, table: str) -> str: - out = schema_sql_from_admin_views(schema, table, self.db) + out = schema_sql_from_admin_views(schema, table, None, db_conn=self.db_conn) if out is None: return super().schema_sql(schema=schema, table=table) else: @@ -52,7 +59,7 @@ def schema_sql(self, schema: str, table: str) -> str: # tables and columns filled up memory in the job. @timeout_decorator.timeout(80) def table(self, schema: str, table: str) -> Table: - with self.db.engine.connect() as conn: + with self.db_engine.connect() as conn: with conn.begin(): # The code in the Redshift SQLAlchemy driver relies on 'SET # LOCAL search_path' in order to do reflection and pull back @@ -68,19 +75,30 @@ def table(self, schema: str, table: str) -> Table: # 94723ec6437c5e5197fcf785845499e81640b167 return Table(table, self.meta, schema=schema, autoload_with=conn) - def set_grant_permissions_for_groups(self, schema_name: str, table: str, + def set_grant_permissions_for_groups(self, + schema_name: str, + table: str, groups: Dict[str, List[str]], - db: sqlalchemy.engine.Engine) -> None: - schema_and_table = quote_schema_and_table(self.db.engine, schema_name, table) + db: Optional[Union[sqlalchemy.engine.Engine, + sqlalchemy.engine.Connection]], + db_conn: Optional[sqlalchemy.engine.Connection] = None, + db_engine: Optional[sqlalchemy.engine.Engine] = None + ) -> None: + db, db_conn, db_engine = check_db_conn_engine(db=db, db_conn=db_conn, db_engine=db_engine) + schema_and_table = quote_schema_and_table(None, schema_name, table, db_engine=db_engine) for perm_type in groups: groups_list = groups[perm_type] for group in groups_list: - group_name: str = quote_group_name(self.db.engine, group) + group_name: str = quote_group_name(None, group, db_engine=self.db_engine) if not perm_type.isalpha(): raise TypeError("Please make sure your permission types" " are an acceptable value.") perms_sql = f'GRANT {perm_type} ON TABLE {schema_and_table} TO GROUP {group_name}' - db.execute(perms_sql) + if db_conn: + db_conn.execute(perms_sql) + else: + with db_engine.connect() as conn: + conn.execute(perms_sql) return None def supports_time_type(self) -> bool: diff --git a/records_mover/db/redshift/sql.py b/records_mover/db/redshift/sql.py index 0bc62e035..63e25789d 100644 --- a/records_mover/db/redshift/sql.py +++ b/records_mover/db/redshift/sql.py @@ -1,7 +1,10 @@ +# flake8: noqa + import sqlalchemy 
import logging -from typing import Optional +from typing import Union, Optional from sqlalchemy.sql import text +from ...check_db_conn_engine import check_db_conn_engine logger = logging.getLogger(__name__) @@ -19,11 +22,15 @@ def schema_sql_from_admin_views(schema: str, table: str, - db: sqlalchemy.engine.Engine)\ + db: Optional[Union[sqlalchemy.engine.Engine, + sqlalchemy.engine.Connection]], + db_conn: Optional[sqlalchemy.engine.Connection] = None, + db_engine: Optional[sqlalchemy.engine.Engine] = None)\ -> Optional[str]: # The default behavior in current sqlalchemy driver seems to # map "DOUBLE PRECISION" to "DOUBLE_PRECISION", so lean back # on admin views for now: + db, db_conn, db_engine = check_db_conn_engine(db=db, db_conn=db_conn, db_engine=db_engine) sql = text(""" SELECT ddl FROM admin.v_generate_tbl_ddl @@ -32,10 +39,9 @@ def schema_sql_from_admin_views(schema: str, """) out: Optional[str] try: - with db.connect() as connection: - with connection.begin(): - result = connection.execute(sql, {'schema_name': schema, - 'table_name': table}).fetchall() + result = db_conn.execute(sql, # type: ignore[union-attr] + {'schema_name': schema, + 'table_name': table}).fetchall() if len(result) == 0: out = None else: diff --git a/records_mover/db/redshift/unloader.py b/records_mover/db/redshift/unloader.py index ee38fc34f..e6963ba9f 100644 --- a/records_mover/db/redshift/unloader.py +++ b/records_mover/db/redshift/unloader.py @@ -12,12 +12,13 @@ from ...records.records_format import ( BaseRecordsFormat, DelimitedRecordsFormat, ParquetRecordsFormat ) -from typing import Callable, Optional, List, Iterator +from typing import Union, Callable, Optional, List, Iterator from ...url.base import BaseDirectoryUrl from botocore.credentials import Credentials from ..errors import CredsDoNotSupportS3Export, NoTemporaryBucketConfiguration from ...records.delimited import complain_on_unhandled_hints from ..unloader import Unloader +from ...check_db_conn_engine import check_db_conn_engine logger = logging.getLogger(__name__) @@ -25,11 +26,14 @@ class RedshiftUnloader(Unloader): def __init__(self, - db: sqlalchemy.engine.Engine, + db: Optional[Union[sqlalchemy.engine.Engine, sqlalchemy.engine.Connection]], table: Callable[[str, str], Table], s3_temp_base_loc: Optional[BaseDirectoryUrl], + db_conn: Optional[sqlalchemy.engine.Connection] = None, + db_engine: Optional[sqlalchemy.engine.Engine] = None, **kwargs) -> None: - super().__init__(db=db) + check_db_conn_engine(db=db, db_conn=db_conn, db_engine=db_engine) + super().__init__(db=db, db_conn=db_conn, db_engine=db_engine) self.table = table self.s3_temp_base_loc = s3_temp_base_loc @@ -80,17 +84,17 @@ def unload_to_s3_directory(self, # register_secret(aws_creds.token) register_secret(aws_creds.secret_key) - select = text(f"SELECT * FROM {quote_schema_and_table(self.db, schema, table)}") + select = text( + "SELECT * FROM " + f"{quote_schema_and_table(None, schema, table, db_engine=self.db_engine)}") unload = UnloadFromSelect(select=select, access_key_id=aws_creds.access_key, secret_access_key=aws_creds.secret_key, session_token=aws_creds.token, manifest=True, unload_location=directory.loc.url, **redshift_options) try: - with self.db.connect() as connection: - with connection.begin(): - connection.execute(unload) - out = connection.execute(text("SELECT pg_last_unload_count()")) + self.db_conn.execute(unload) + out = self.db_conn.execute(text("SELECT pg_last_unload_count()")) rows: Optional[int] = out.scalar() assert rows is not None logger.info(f"Just unloaded 
{rows} rows") diff --git a/records_mover/db/unloader.py b/records_mover/db/unloader.py index c1e971f60..1a6fea9aa 100644 --- a/records_mover/db/unloader.py +++ b/records_mover/db/unloader.py @@ -4,17 +4,44 @@ from ..records.records_format import BaseRecordsFormat from ..records.records_directory import RecordsDirectory from records_mover.url.base import BaseDirectoryUrl -from typing import List, Optional, Iterator +from typing import Union, List, Optional, Iterator from abc import ABCMeta, abstractmethod import sqlalchemy +from ..check_db_conn_engine import check_db_conn_engine +import logging + + +logger = logging.getLogger(__name__) class Unloader(metaclass=ABCMeta): def __init__(self, - db: sqlalchemy.engine.Engine) -> None: + db: Optional[Union[sqlalchemy.engine.Connection, sqlalchemy.engine.Engine]], + db_conn: Optional[sqlalchemy.engine.Connection] = None, + db_engine: Optional[sqlalchemy.engine.Engine] = None) -> None: + db, db_conn, db_engine = check_db_conn_engine(db=db, db_conn=db_conn, db_engine=db_engine) self.db = db + self._db_conn = db_conn + self.db_engine = db_engine + self.conn_opened_here = False self.meta = MetaData() + def get_db_conn(self) -> sqlalchemy.engine.Connection: + if self._db_conn is None: + self._db_conn = self.db_engine.connect() + self.conn_opened_here = True + logger.debug("Opened connection to database which will be closed inside this unloader.") + return self._db_conn + + def set_db_conn(self, db_conn: Optional[sqlalchemy.engine.Connection]) -> None: + self._db_conn = db_conn + + def del_db_conn(self) -> None: + if self.conn_opened_here: + self.db_conn.close() + + db_conn = property(get_db_conn, set_db_conn, del_db_conn) + @abstractmethod def unload(self, schema: str, @@ -65,3 +92,6 @@ def best_records_format(self) -> Optional[BaseRecordsFormat]: if len(supported_formats) == 0: return None return supported_formats[0] + + def __del__(self) -> None: + self.del_db_conn() diff --git a/records_mover/db/vertica/export_sql.py b/records_mover/db/vertica/export_sql.py index f58947b76..9ab4824d8 100644 --- a/records_mover/db/vertica/export_sql.py +++ b/records_mover/db/vertica/export_sql.py @@ -21,7 +21,7 @@ def vertica_export_sql(db_engine: Engine, """ def quote(value: str) -> str: - return quote_value(db_engine, value) + return quote_value(None, value, db_engine=db_engine) params_data = { "url": quote(s3_url + 'records.csv'), @@ -32,7 +32,7 @@ def quote(value: str) -> str: } params = ", ".join([f"{key}={value}" for key, value in params_data.items()]) - schema_and_table = quote_schema_and_table(db_engine, schema, table) + schema_and_table = quote_schema_and_table(None, schema, table, db_engine=db_engine) sql = template.format(params=params, schema_and_table=schema_and_table) diff --git a/records_mover/db/vertica/import_sql.py b/records_mover/db/vertica/import_sql.py index da7662043..b4005dd6f 100644 --- a/records_mover/db/vertica/import_sql.py +++ b/records_mover/db/vertica/import_sql.py @@ -55,14 +55,14 @@ def quote(value: str) -> str: return quote_value(db_engine, value) if rejected_data_table is not None and rejected_data_schema is not None: - rejected_target = quote_schema_and_table(db_engine, - rejected_data_schema, rejected_data_table) + rejected_target = quote_schema_and_table(None, rejected_data_schema, rejected_data_table, + db_engine=db_engine) rejected_data = f"REJECTED DATA AS TABLE {rejected_target}" else: rejected_data = '' import_sql = import_sql_template.format( - schema_and_table=quote_schema_and_table(db_engine, schema, table), +
schema_and_table=quote_schema_and_table(None, schema, table, db_engine=db_engine), # https://forum.vertica.com/discussion/238556/reading-gzip-files-from-s3-into-vertica gzip='GZIP' if gzip else '', # The capital E in the next line specifies a string literal diff --git a/records_mover/db/vertica/loader.py b/records_mover/db/vertica/loader.py index e4cf93189..c93e2ad17 100644 --- a/records_mover/db/vertica/loader.py +++ b/records_mover/db/vertica/loader.py @@ -10,8 +10,9 @@ from ...records.records_format import DelimitedRecordsFormat, BaseRecordsFormat from ...records.processing_instructions import ProcessingInstructions from ..loader import LoaderFromFileobj -from typing import IO, List, Type +from typing import IO, Union, List, Type, Optional import logging +from ...check_db_conn_engine import check_db_conn_engine logger = logging.getLogger(__name__) @@ -19,9 +20,14 @@ class VerticaLoader(LoaderFromFileobj): def __init__(self, url_resolver: UrlResolver, - db: sqlalchemy.engine.Engine) -> None: + db: Optional[Union[sqlalchemy.engine.Connection, sqlalchemy.engine.Engine]], + db_conn: Optional[sqlalchemy.engine.Connection] = None, + db_engine: Optional[sqlalchemy.engine.Engine] = None) -> None: + db, db_conn, db_engine = check_db_conn_engine(db=db, db_conn=db_conn, db_engine=db_engine) self.url_resolver = url_resolver self.db = db + self.db_conn = db_conn + self.db_engine = db_engine def load_from_fileobj(self, schema: str, @@ -41,11 +47,11 @@ def load_from_fileobj(self, # vertica_options isn't yet a TypedDict that matches the # vertica_import_sql options, so suppress type checking - import_sql = vertica_import_sql(db_engine=self.db.engine, table=table, + import_sql = vertica_import_sql(db_engine=self.db_engine, table=table, schema=schema, **vertica_options) # type: ignore rawconn = None try: - rawconn = self.db.engine.raw_connection() + rawconn = self.db_engine.raw_connection() cursor = rawconn.cursor() logger.info(import_sql) if isinstance(fileobj, urllib.response.addinfourl): diff --git a/records_mover/db/vertica/unloader.py b/records_mover/db/vertica/unloader.py index 4e2f5f53b..7c2146f60 100644 --- a/records_mover/db/vertica/unloader.py +++ b/records_mover/db/vertica/unloader.py @@ -1,3 +1,4 @@ +from ...check_db_conn_engine import check_db_conn_engine from records_mover.db.quoting import quote_value import sqlalchemy from sqlalchemy import text @@ -13,7 +14,7 @@ from ...records.delimited import complain_on_unhandled_hints from ..unloader import Unloader import logging -from typing import Iterator, Optional, List, TYPE_CHECKING +from typing import Iterator, Optional, Union, List, TYPE_CHECKING if TYPE_CHECKING: from botocore.credentials import Credentials # noqa @@ -23,9 +24,12 @@ class VerticaUnloader(Unloader): def __init__(self, - db: sqlalchemy.engine.Engine, - s3_temp_base_loc: Optional[BaseDirectoryUrl]) -> None: - super().__init__(db=db) + db: Optional[Union[sqlalchemy.engine.Connection, sqlalchemy.engine.Engine]], + s3_temp_base_loc: Optional[BaseDirectoryUrl], + db_conn: Optional[sqlalchemy.engine.Connection] = None, + db_engine: Optional[sqlalchemy.engine.Engine] = None) -> None: + db, db_conn, db_engine = check_db_conn_engine(db=db, db_conn=db_conn, db_engine=db_engine) + super().__init__(db=db, db_conn=db_conn, db_engine=db_engine) self.s3_temp_base_loc = s3_temp_base_loc @contextmanager @@ -41,8 +45,8 @@ def aws_creds_sql(self, aws_id: str, aws_secret: str) -> str: return """ ALTER SESSION SET UDPARAMETER FOR awslib aws_id={aws_id}; ALTER SESSION SET UDPARAMETER FOR awslib 
aws_secret={aws_secret}; - """.format(aws_id=quote_value(self.db.engine, aws_id), - aws_secret=quote_value(self.db.engine, aws_secret)) + """.format(aws_id=quote_value(None, aws_id, db_engine=self.db_engine), + aws_secret=quote_value(None, aws_secret, db_engine=self.db_engine)) def unload(self, schema: str, @@ -70,13 +74,11 @@ def s3_temp_bucket_available(self) -> bool: return self.s3_temp_base_loc is not None def s3_export_available(self) -> bool: - with self.db.connect() as connection: - out = connection.execute(text("SELECT lib_name " - "from user_libraries " - "where lib_name = 'awslib'")) - available = len(list(out.fetchall())) == 1 - if not available: - logger.info("Not attempting S3 export - no access to awslib in Vertica") + out = self.db_conn.execute( + text("SELECT lib_name from user_libraries where lib_name = 'awslib'")) + available = len(list(out.fetchall())) == 1 + if not available: + logger.info("Not attempting S3 export - no access to awslib in Vertica") return available def unload_to_s3_directory(self, @@ -101,9 +103,7 @@ def unload_to_s3_directory(self, processing_instructions = unload_plan.processing_instructions try: s3_sql = self.aws_creds_sql(aws_creds.access_key, aws_creds.secret_key) - with self.db.connect() as connection: - with connection.begin(): - connection.execute(text(s3_sql)) + self.db_conn.execute(text(s3_sql)) except sqlalchemy.exc.ProgrammingError as e: raise DatabaseDoesNotSupportS3Export(str(e)) @@ -112,15 +112,13 @@ def unload_to_s3_directory(self, unhandled_hints, unload_plan.records_format.hints) - export_sql = vertica_export_sql(db_engine=self.db.engine, + export_sql = vertica_export_sql(db_engine=self.db_engine, table=table, schema=schema, s3_url=directory.loc.url, **vertica_options) logger.info(export_sql) - with self.db.connect() as connection: - with connection.begin(): - export_result = connection.execute(text(export_sql)).fetchall() + export_result = self.db_conn.execute(text(export_sql)).fetchall() directory.save_preliminary_manifest() export_count = 0 for record in export_result: diff --git a/records_mover/db/vertica/vertica_db_driver.py b/records_mover/db/vertica/vertica_db_driver.py index 478d0b6ed..bc61b2548 100644 --- a/records_mover/db/vertica/vertica_db_driver.py +++ b/records_mover/db/vertica/vertica_db_driver.py @@ -4,7 +4,7 @@ from records_mover.db.quoting import quote_schema_and_table from sqlalchemy.schema import Table, Column import logging -from typing import Optional, Tuple +from typing import Optional, Union, Tuple from ...url.resolver import UrlResolver from ...url.base import BaseDirectoryUrl from .loader import VerticaLoader @@ -14,6 +14,7 @@ from ...utils.limits import (INT64_MIN, INT64_MAX, FLOAT64_SIGNIFICAND_BITS, num_digits) +from ...check_db_conn_engine import check_db_conn_engine logger = logging.getLogger(__name__) @@ -21,13 +22,18 @@ class VerticaDBDriver(DBDriver): def __init__(self, - db: sqlalchemy.engine.Engine, + db: Optional[Union[sqlalchemy.engine.Connection, sqlalchemy.engine.Engine]], url_resolver: UrlResolver, s3_temp_base_loc: Optional[BaseDirectoryUrl] = None, + db_conn: Optional[sqlalchemy.engine.Connection] = None, + db_engine: Optional[sqlalchemy.engine.Engine] = None, **kwargs: object) -> None: - super().__init__(db) - self._vertica_loader = VerticaLoader(url_resolver=url_resolver, db=self.db) - self._vertica_unloader = VerticaUnloader(s3_temp_base_loc=s3_temp_base_loc, db=db) + db, db_conn, db_engine = check_db_conn_engine(db=db, db_conn=db_conn, db_engine=db_engine) + super().__init__(db=db, 
db_conn=db_conn, db_engine=db_engine) + self._vertica_loader = VerticaLoader(url_resolver=url_resolver, db=db, db_conn=db_conn, + db_engine=db_engine) + self._vertica_unloader = VerticaUnloader(s3_temp_base_loc=s3_temp_base_loc, db=db, + db_conn=db_conn, db_engine=db_engine) self.url_resolver = url_resolver def loader(self) -> Optional[LoaderFromRecordsDirectory]: @@ -41,16 +47,17 @@ def unloader(self) -> Optional[Unloader]: def has_table(self, schema: str, table: str) -> bool: try: - sql = f"SELECT 1 from {quote_schema_and_table(self.db, schema, table)} limit 0;" - self.db.execute(text(sql)) + sql = ("SELECT 1 " + f"from {quote_schema_and_table(None, schema, table, db_engine=self.db_engine)} " + "limit 0;") + self.db_conn.execute(text(sql)) return True except sqlalchemy.exc.ProgrammingError: return False def schema_sql(self, schema: str, table: str) -> str: - sql = text("SELECT EXPORT_OBJECTS('', :schema_and_table, false)") - with self.db.connect() as connection: - result = connection.execute(sql, {'schema_and_table': f"{schema}.{table}"}).fetchall() + sql = text("SELECT EXPORT_OBJECTS('', :schema_and_table , false)") + result = self.db_conn.execute(sql, {'schema_and_table': f"{schema}.{table}"}).fetchall() if len(result) == 1: return result[0].EXPORT_OBJECTS @@ -68,16 +75,14 @@ def table(self, # https://docs.sqlalchemy.org/en/latest/core/metadata.html#sqlalchemy.schema.Column # https://docs.sqlalchemy.org/en/latest/core/reflection.html# # sqlalchemy.engine.reflection.Inspector.get_columns - with self.db_engine.connect() as connection: - with connection.begin(): - columns = [Column(colinfo['name'], - type_=colinfo['type'], - nullable=colinfo['nullable'], - default=colinfo['default'], - **colinfo.get('attrs', {})) - for colinfo in self.db_engine.dialect.get_columns(connection, - table, - schema=schema)] + columns = [Column(colinfo['name'], + type_=colinfo['type'], + nullable=colinfo['nullable'], + default=colinfo['default'], + **colinfo.get('attrs', {})) + for colinfo in self.db_engine.dialect.get_columns(self.db_conn, + table, + schema=schema)] t = Table(table, self.meta, schema=schema, *columns) return t diff --git a/records_mover/records/delimited/sniff.py b/records_mover/records/delimited/sniff.py index e52d5ceb8..841c296d0 100644 --- a/records_mover/records/delimited/sniff.py +++ b/records_mover/records/delimited/sniff.py @@ -1,3 +1,5 @@ +# flake8: noqa + import chardet from contextlib import contextmanager from . 
import PartialRecordsHints @@ -268,10 +270,10 @@ def sniff_hints(fileobj: IO[bytes], # out: PartialRecordsHints = { 'compression': compression_hint, - **pandas_inferred_hints, # type: ignore + **pandas_inferred_hints, **python_inferred_hints, 'encoding': final_encoding_hint, - **other_inferred_csv_hints, + **other_inferred_csv_hints, # type: ignore[typeddict-item] **initial_hints } logger.info(f"Inferred hints from combined sources: {out}") diff --git a/records_mover/records/mover.py b/records_mover/records/mover.py index 5a4aa71c6..8829da9bc 100644 --- a/records_mover/records/mover.py +++ b/records_mover/records/mover.py @@ -117,7 +117,8 @@ def move(records_source: RecordsSource, elif isinstance(records_source, SupportsToFileobjsSource): # Incompatible types in assignment (expression has type "Optional[Any]", # variable has type "BaseRecordsFormat") - target_records_format: BaseRecordsFormat = getattr(records_target, "records_format", None) # type: ignore + target_records_format: BaseRecordsFormat \ + = getattr(records_target, "records_format", None) # type: ignore logger.info(f"Mover: copying from {records_source} to {records_target} " f"by first writing {records_source} to {target_records_format} " "records format (if easy to rewrite)...") diff --git a/records_mover/records/prep.py b/records_mover/records/prep.py index 0190a93ae..d3a5caecd 100644 --- a/records_mover/records/prep.py +++ b/records_mover/records/prep.py @@ -1,11 +1,11 @@ from typing import Optional from sqlalchemy.engine import Connection -from sqlalchemy import text from records_mover.db.quoting import quote_schema_and_table from records_mover.records.existing_table_handling import ExistingTableHandling from records_mover.db import DBDriver from records_mover.records.table import TargetTableDetails import logging +from sqlalchemy import text logger = logging.getLogger(__name__) @@ -14,17 +14,19 @@ class TablePrep: def __init__(self, target_table_details: TargetTableDetails) -> None: self.tbl = target_table_details - def add_permissions(self, driver: DBDriver) -> None: - schema_and_table: str = quote_schema_and_table(driver.db, + def add_permissions(self, conn: Connection, driver: DBDriver) -> None: + schema_and_table: str = quote_schema_and_table(None, self.tbl.schema_name, - self.tbl.table_name) + self.tbl.table_name, + db_engine=driver.db_engine) if self.tbl.add_group_perms_for is not None: logger.info(f"Adding permissions for {schema_and_table} " f"to group {self.tbl.add_group_perms_for}") driver.set_grant_permissions_for_groups(self.tbl.schema_name, self.tbl.table_name, self.tbl.add_group_perms_for, - driver.db_engine) + None, + db_conn=conn) if self.tbl.add_user_perms_for is not None: logger.info(f"Adding permissions for {schema_and_table} " f"to {self.tbl.add_user_perms_for}") @@ -32,13 +34,14 @@ def add_permissions(self, driver: DBDriver) -> None: set_grant_permissions_for_users(self.tbl.schema_name, self.tbl.table_name, self.tbl.add_user_perms_for, - driver.db_engine) + None, + db_conn=conn) def create_table(self, schema_sql: str, conn: Connection, driver: DBDriver) -> None: logger.info('Creating table...') - conn.exec_driver_sql(schema_sql) # type: ignore + conn.execute(text(schema_sql)) logger.info(f"Just ran {schema_sql}") - self.add_permissions(driver) + self.add_permissions(conn, driver) logger.info("Table prepped") def prep_table_for_load(self, @@ -46,28 +49,31 @@ def prep_table_for_load(self, existing_table_handling: ExistingTableHandling, driver: DBDriver) -> None: logger.info("Looking for existing 
table..") - db = driver.db + db_engine = driver.db_engine + db_conn = driver.db_conn if driver.has_table(table=self.tbl.table_name, schema=self.tbl.schema_name): logger.info("Table already exists.") how_to_prep = existing_table_handling - schema_and_table: str = quote_schema_and_table(db, + schema_and_table: str = quote_schema_and_table(None, self.tbl.schema_name, - self.tbl.table_name) + self.tbl.table_name, + db_engine=db_engine,) if (how_to_prep == ExistingTableHandling.TRUNCATE_AND_OVERWRITE): logger.info("Truncating...") - db.execute(text(f"TRUNCATE TABLE {schema_and_table}")) + db_conn.execute(text(f"TRUNCATE TABLE {schema_and_table}")) logger.info("Truncated.") elif (how_to_prep == ExistingTableHandling.DELETE_AND_OVERWRITE): logger.info("Deleting rows...") - db.execute(text(f"DELETE FROM {schema_and_table} WHERE true")) + db_conn.execute(text(f"DELETE FROM {schema_and_table} WHERE true")) logger.info("Deleted") elif (how_to_prep == ExistingTableHandling.DROP_AND_RECREATE): - with db.engine.connect() as conn: + with db_engine.connect() as conn: with conn.begin(): + logger.info(f"The connection object is: {conn}") logger.info("Dropping and recreating...") drop_table_sql = f"DROP TABLE {schema_and_table}" - conn.execute(drop_table_sql) + conn.execute(text(drop_table_sql)) logger.info(f"Just ran {drop_table_sql}") self.create_table(schema_sql, conn, driver) elif (how_to_prep == ExistingTableHandling.APPEND): @@ -75,7 +81,7 @@ def prep_table_for_load(self, else: raise ValueError(f"Don't know how to handle {how_to_prep}") else: - with db.engine.connect() as conn: + with db_engine.connect() as conn: with conn.begin(): self.create_table(schema_sql, conn, driver) diff --git a/records_mover/records/prep_and_load.py b/records_mover/records/prep_and_load.py index 3f099a9af..eb01f5cbd 100644 --- a/records_mover/records/prep_and_load.py +++ b/records_mover/records/prep_and_load.py @@ -16,24 +16,24 @@ def prep_and_load(tbl: TargetTableDetails, load_exception_type: Type[Exception], reset_before_reload: Callable[[], None] = lambda: None) -> MoveResult: logger.info("Connecting to database...") - with tbl.db_engine.begin() as db: - driver = tbl.db_driver(db) + with tbl.db_engine.begin() as db_conn: + driver = tbl.db_driver(db=None, db_conn=db_conn) prep.prep(schema_sql=schema_sql, driver=driver) - with tbl.db_engine.begin() as db: + with tbl.db_engine.begin() as db_conn: # This second transaction ensures the table has been created # before non-transactional statements like Redshift's COPY # take place. 
Otherwise you'll get an error like: # # Cannot COPY into nonexistent table - driver = tbl.db_driver(db) + driver = tbl.db_driver(db=None, db_conn=db_conn) try: import_count = load(driver) except load_exception_type: if not tbl.drop_and_recreate_on_load_error: raise reset_before_reload() - with tbl.db_engine.begin() as db: - driver = tbl.db_driver(db) + with tbl.db_engine.begin() as db_conn: + driver = tbl.db_driver(db=None, db_conn=db_conn) prep.prep(schema_sql=schema_sql, driver=driver, existing_table_handling=ExistingTableHandling.DROP_AND_RECREATE) diff --git a/records_mover/records/records.py b/records_mover/records/records.py index a3408ca8e..190d03a25 100644 --- a/records_mover/records/records.py +++ b/records_mover/records/records.py @@ -4,10 +4,10 @@ from .targets import RecordsTargets from .mover import move from enum import Enum -from typing import Callable, Union, TYPE_CHECKING +from typing import Callable, Union, Optional, TYPE_CHECKING if TYPE_CHECKING: from sqlalchemy.engine import Engine, Connection # noqa - from ..db import DBDriver # noqa + from ..db import DBDriver # noqa from records_mover import Session # noqa @@ -60,7 +60,9 @@ class Records: "Alias of :meth:`records_mover.records.move`" def __init__(self, - db_driver: Union[Callable[['Engine'], 'DBDriver'], + db_driver: Union[Callable[[Union['Engine', 'Connection', None], + Optional['Connection'], + Optional['Engine']], 'DBDriver'], PleaseInfer] = PleaseInfer.token, url_resolver: Union[UrlResolver, PleaseInfer] = PleaseInfer.token, session: Union['Session', PleaseInfer] = PleaseInfer.token) -> None: diff --git a/records_mover/records/schema/field/sqlalchemy.py b/records_mover/records/schema/field/sqlalchemy.py index 0f4b61977..444594c47 100644 --- a/records_mover/records/schema/field/sqlalchemy.py +++ b/records_mover/records/schema/field/sqlalchemy.py @@ -107,7 +107,7 @@ def field_from_sqlalchemy_column(column: Column, driver=driver) representations = { 'origin': RecordsSchemaFieldRepresentation. - from_sqlalchemy_column(column, driver.db.dialect, + from_sqlalchemy_column(column, driver.db_engine.dialect, rep_type) } diff --git a/records_mover/records/schema/schema/known_representation.py b/records_mover/records/schema/schema/known_representation.py index 72f86ce39..52aa802b8 100644 --- a/records_mover/records/schema/schema/known_representation.py +++ b/records_mover/records/schema/schema/known_representation.py @@ -60,7 +60,7 @@ def from_dataframe(df: 'DataFrame', def from_db_driver(driver: 'DBDriver', schema_name: str, table_name: str) -> 'RecordsSchemaSqlKnownRepresentation': - type = f"sql/{driver.db.dialect.name}" + type = f"sql/{driver.db_engine.dialect.name}" ddl = driver.schema_sql(schema_name, table_name) return RecordsSchemaSqlKnownRepresentation(type=type, table_ddl=ddl) diff --git a/records_mover/records/sources/factory.py b/records_mover/records/sources/factory.py index 19cf5326f..8228f4ece 100644 --- a/records_mover/records/sources/factory.py +++ b/records_mover/records/sources/factory.py @@ -1,3 +1,5 @@ +# flake8: noqa + import pathlib from ..records_format import BaseRecordsFormat from ..schema import RecordsSchema @@ -15,7 +17,7 @@ if TYPE_CHECKING: # see the 'gsheets' extras_require option in setup.py - needed for this! 
import google.auth.credentials # noqa - from sqlalchemy.engine import Engine # noqa + from sqlalchemy.engine import Engine, Connection # noqa from ...db import DBDriver # noqa from .google_sheets import GoogleSheetsRecordsSource # noqa ditto # with pandas, which an optional addition for clients of this @@ -50,7 +52,9 @@ class RecordsSources(object): """ def __init__(self, - db_driver: Callable[['Engine'], 'DBDriver'], + db_driver: Callable[[Optional[Union['Engine', 'Connection']], + Optional['Connection'], + Optional['Engine']], 'DBDriver'], url_resolver: UrlResolver) -> None: self.db_driver = db_driver self.url_resolver = url_resolver @@ -158,18 +162,22 @@ def data_url(self, def table(self, db_engine: 'Engine', schema_name: str, - table_name: str) -> 'TableRecordsSource': + table_name: str, + db_conn: Optional['Connection'] = None) -> 'TableRecordsSource': """Represents a SQLALchemy-accessible database table as as a source. :param db_engine: SQLAlchemy database engine to pull data from. :param schema_name: Schema name of a table to get data from. :param table_name: Table name of a table to get data from. + :param db_conn: SQLAlchemy database connection to use to pull data from. """ from .table import TableRecordsSource # noqa - return TableRecordsSource(schema_name=schema_name, - table_name=table_name, - url_resolver=self.url_resolver, - driver=self.db_driver(db_engine)) + return TableRecordsSource( + schema_name=schema_name, + table_name=table_name, + url_resolver=self.url_resolver, + driver=self.db_driver(None, db_engine=db_engine, # type: ignore[call-arg] + db_conn=db_conn)) def directory_from_url(self, url: str, diff --git a/records_mover/records/sources/table.py b/records_mover/records/sources/table.py index 4c2c14b8c..4d7826e95 100644 --- a/records_mover/records/sources/table.py +++ b/records_mover/records/sources/table.py @@ -75,15 +75,11 @@ def to_dataframes_source(self, from .dataframes import DataframesRecordsSource # noqa import pandas - db = self.driver.db + db_conn = self.driver.db_conn + db_engine = self.driver.db_engine records_schema = self.pull_records_schema() - if isinstance(db, Engine): - connection = db.connect() - columns = db.dialect.get_columns(connection, self.table_name, schema=self.schema_name) - connection.close() - else: - columns = db.dialect.get_columns(db, self.table_name, schema=self.schema_name) + columns = db_engine.dialect.get_columns(db_conn, self.table_name, schema=self.schema_name) num_columns = len(columns) if num_columns == 0: @@ -92,20 +88,19 @@ def to_dataframes_source(self, chunksize = int(entries_per_chunk / num_columns) logger.info(f"Exporting in chunks of up to {chunksize} rows by {num_columns} columns") - quoted_table = quote_schema_and_table(db, self.schema_name, self.table_name) - with db.connect() as connection: - with connection.begin(): - chunks: Generator['DataFrame', None, None] = \ - pandas.read_sql(text(f"SELECT * FROM {quoted_table}"), - con=connection, - chunksize=chunksize) - try: - yield DataframesRecordsSource(dfs=self.with_cast_dataframe_types(records_schema, - chunks), - records_schema=records_schema, - processing_instructions=processing_instructions) - finally: - chunks.close() + quoted_table = quote_schema_and_table(None, self.schema_name, + self.table_name, db_engine=db_engine,) + chunks: Generator['DataFrame', None, None] = \ + pandas.read_sql(text(f"SELECT * FROM {quoted_table}"), + con=db_conn, + chunksize=chunksize) + try: + yield DataframesRecordsSource(dfs=self.with_cast_dataframe_types(records_schema, + chunks), + 
records_schema=records_schema, + processing_instructions=processing_instructions) + finally: + chunks.close() def with_cast_dataframe_types(self, records_schema: RecordsSchema, diff --git a/records_mover/records/table.py b/records_mover/records/table.py index c7efb703d..c94bb6958 100644 --- a/records_mover/records/table.py +++ b/records_mover/records/table.py @@ -1,6 +1,6 @@ from abc import ABCMeta -from sqlalchemy.engine import Engine -from typing import Optional, Dict, List +from sqlalchemy.engine import Engine, Connection +from typing import Union, Optional, Dict, List from records_mover.records.existing_table_handling import ExistingTableHandling from records_mover.db import DBDriver import logging @@ -24,5 +24,8 @@ class TargetTableDetails(metaclass=ABCMeta): # # https://github.com/python/mypy/issues/5485 # @abstractmethod - def db_driver(self, db: Engine) -> DBDriver: # type: ignore + def db_driver(self, # type: ignore + db: Optional[Union[Engine, Connection]], + db_engine: Optional[Engine] = None, + db_conn: Optional[Connection] = None) -> DBDriver: ... diff --git a/records_mover/records/targets/factory.py b/records_mover/records/targets/factory.py index 5b83adb2f..2d1872c5f 100644 --- a/records_mover/records/targets/factory.py +++ b/records_mover/records/targets/factory.py @@ -4,7 +4,7 @@ from .fileobj import FileobjTarget from .directory_from_url import DirectoryFromUrlRecordsTarget from .data_url import DataUrlTarget -from typing import Callable, Optional, Dict, List, IO, TYPE_CHECKING +from typing import Callable, Optional, Union, Dict, List, IO, TYPE_CHECKING from ..existing_table_handling import ExistingTableHandling if TYPE_CHECKING: # see the 'gsheets' extras_require option in setup.py - needed for this! @@ -43,7 +43,9 @@ class RecordsTargets(object): def __init__(self, url_resolver: UrlResolver, - db_driver: Callable[['Engine'], 'DBDriver']) -> None: + db_driver: Callable[[Optional[Union['Engine', 'Connection']], + Optional['Connection'], + Optional['Engine']], 'DBDriver']) -> None: self.url_resolver = url_resolver self.db_driver = db_driver @@ -73,7 +75,8 @@ def table(self, ExistingTableHandling.DELETE_AND_OVERWRITE, drop_and_recreate_on_load_error: bool = False, add_user_perms_for: Optional[Dict[str, List[str]]] = None, - add_group_perms_for: Optional[Dict[str, List[str]]] = None) -> \ + add_group_perms_for: Optional[Dict[str, List[str]]] = None, + db_conn: Optional['Connection'] = None) -> \ 'TableRecordsTarget': """Represents a SQLALchemy-accessible database table as as a target. @@ -97,6 +100,9 @@ def table(self, :param add_group_perms_for: If specified, a table's permissions will be set for the specified group. Format should be like {'all': ['group1', 'group2'], 'select': ['group3', 'group4']} + + :param db_conn: SQLAlchemy database connection to write data to. If not specified, one + will be created from the db_engine. 
""" from .table import TableRecordsTarget # noqa return TableRecordsTarget(schema_name=schema_name, @@ -106,7 +112,8 @@ def table(self, existing_table_handling=existing_table_handling, drop_and_recreate_on_load_error=drop_and_recreate_on_load_error, add_user_perms_for=add_user_perms_for, - add_group_perms_for=add_group_perms_for) + add_group_perms_for=add_group_perms_for, + db_conn=db_conn) def google_sheet(self, spreadsheet_id: str, diff --git a/records_mover/records/targets/spectrum.py b/records_mover/records/targets/spectrum.py index 2039319fa..7ea1a1dac 100644 --- a/records_mover/records/targets/spectrum.py +++ b/records_mover/records/targets/spectrum.py @@ -1,14 +1,16 @@ +# flake8: noqa + from .base import SupportsRecordsDirectory from records_mover.db.quoting import quote_schema_and_table from ...db import DBDriver from ...url.resolver import UrlResolver from ...url import BaseDirectoryUrl -from sqlalchemy.engine import Engine +from sqlalchemy.engine import Engine, Connection from ..records_directory import RecordsDirectory from ..records_format import ParquetRecordsFormat from sqlalchemy.schema import CreateTable, MetaData, Table from ..existing_table_handling import ExistingTableHandling -from typing import Optional, Callable +from typing import Optional, Callable, Union import logging import sqlalchemy from sqlalchemy import text @@ -22,14 +24,19 @@ def __init__(self, schema_name: str, table_name: str, db_engine: Engine, - db_driver: Callable[[Engine], DBDriver], + db_driver: Callable[[Optional[Union['Engine', 'Connection']], + Optional[Connection], + Optional[Engine]], DBDriver], url_resolver: UrlResolver, spectrum_base_url: Optional[str], spectrum_rdir_url: Optional[str], existing_table_handling: ExistingTableHandling = ExistingTableHandling.TRUNCATE_AND_OVERWRITE) -> None: - self.db = db_engine - self.driver = db_driver(db_engine) + self.db = None + self.db_engine = db_engine + self.driver = db_driver(db=None, # type: ignore[call-arg] + db_conn=None, + db_engine=db_engine) self.schema_name = schema_name self.table_name = table_name self.url_resolver = url_resolver @@ -71,11 +78,12 @@ def prep_bucket(self) -> None: if self.existing_table_handling == ExistingTableHandling.DELETE_AND_OVERWRITE: logger.warning('Redshift Spectrum does not support transactional delete.') if self.existing_table_handling == ExistingTableHandling.DROP_AND_RECREATE: - schema_and_table: str = quote_schema_and_table(self.db, + schema_and_table: str = quote_schema_and_table(None, self.schema_name, - self.table_name) + self.table_name, + db_engine=self.db_engine) logger.info(f"Dropping external table {schema_and_table}...") - with self.db.connect() as cursor: + with self.db_engine.connect() as cursor: # See below note about fix from Spectrify cursor.execution_options(isolation_level='AUTOCOMMIT') cursor.execute(text(f"DROP TABLE IF EXISTS {schema_and_table}")) diff --git a/records_mover/records/targets/table/move_from_dataframes_source.py b/records_mover/records/targets/table/move_from_dataframes_source.py index a288eded1..6e200751a 100644 --- a/records_mover/records/targets/table/move_from_dataframes_source.py +++ b/records_mover/records/targets/table/move_from_dataframes_source.py @@ -85,7 +85,7 @@ def load(self, driver: DBDriver) -> int: return rows_loaded def move_from_dataframes_source_via_insert(self) -> MoveResult: - driver = self.tbl.db_driver(self.tbl.db_engine) + driver = self.tbl.db_driver(db=None, db_engine=self.tbl.db_engine) schema_sql = self.records_schema.to_schema_sql(driver, 
self.tbl.schema_name, self.tbl.table_name) diff --git a/records_mover/records/targets/table/move_from_fileobjs_source.py b/records_mover/records/targets/table/move_from_fileobjs_source.py index 82c595686..abb5c3fbf 100644 --- a/records_mover/records/targets/table/move_from_fileobjs_source.py +++ b/records_mover/records/targets/table/move_from_fileobjs_source.py @@ -53,8 +53,8 @@ def reset_before_reload(self) -> None: self.fileobj.seek(0) def move(self) -> MoveResult: - with self.tbl.db_engine.begin() as db: - driver = self.tbl.db_driver(db) + with self.tbl.db_engine.begin() as db_conn: + driver = self.tbl.db_driver(None, db_conn=db_conn) schema_obj = self.fileobjs_source.records_schema schema_sql = self.schema_sql_for_load(schema_obj, self.records_format, driver) loader_from_fileobj = driver.loader_from_fileobj() diff --git a/records_mover/records/targets/table/move_from_records_directory.py b/records_mover/records/targets/table/move_from_records_directory.py index f4e8034f7..ef9755c9e 100644 --- a/records_mover/records/targets/table/move_from_records_directory.py +++ b/records_mover/records/targets/table/move_from_records_directory.py @@ -61,8 +61,8 @@ def load(self, driver: DBDriver) -> Optional[int]: def move(self) -> MoveResult: logger.info("Connecting to database...") - with self.tbl.db_engine.begin() as db: - driver = self.tbl.db_driver(db) + with self.tbl.db_engine.begin() as db_conn: + driver = self.tbl.db_driver(None, db_conn=db_conn) loader = driver.loader() # If we've gotten here, .can_move_from_format() has # returned True in the move() method, and that can only happen diff --git a/records_mover/records/targets/table/move_from_temp_loc_after_filling_it.py b/records_mover/records/targets/table/move_from_temp_loc_after_filling_it.py index fc725d96a..be7b51875 100644 --- a/records_mover/records/targets/table/move_from_temp_loc_after_filling_it.py +++ b/records_mover/records/targets/table/move_from_temp_loc_after_filling_it.py @@ -27,7 +27,7 @@ def __init__(self, @contextmanager def temporary_loadable_directory_loc(self) -> Iterator[BaseDirectoryUrl]: - driver = self.tbl.db_driver(self.tbl.db_engine) + driver = self.tbl.db_driver(db=None, db_engine=self.tbl.db_engine) loader = driver.loader() # This will only be reached in move() if # Source#has_compatible_format(records_target) returns true, diff --git a/records_mover/records/targets/table/target.py b/records_mover/records/targets/table/target.py index e383b4cbd..da79c0f91 100644 --- a/records_mover/records/targets/table/target.py +++ b/records_mover/records/targets/table/target.py @@ -4,7 +4,7 @@ MightSupportMoveFromFileobjsSource, SupportsMoveFromDataframes, ) -from sqlalchemy.engine import Engine +from sqlalchemy.engine import Engine, Connection from records_mover.records.prep import TablePrep, TargetTableDetails from records_mover.records.records_format import BaseRecordsFormat from records_mover.db import DBDriver @@ -22,7 +22,7 @@ DoMoveFromTempLocAfterFillingIt ) import logging -from typing import Callable, Optional, Dict, List, TYPE_CHECKING +from typing import Callable, Union, Optional, Dict, List, TYPE_CHECKING if TYPE_CHECKING: from records_mover.records.sources.dataframes import DataframesRecordsSource @@ -39,16 +39,20 @@ def __init__(self, schema_name: str, table_name: str, db_engine: Engine, - db_driver: Callable[[Engine], DBDriver], + db_driver: Callable[[Optional[Union['Engine', 'Connection']], + Optional[Connection], + Optional[Engine]], DBDriver], add_user_perms_for: Optional[Dict[str, List[str]]] = None, 
add_group_perms_for: Optional[Dict[str, List[str]]] = None, existing_table_handling: ExistingTableHandling = ExistingTableHandling.DELETE_AND_OVERWRITE, - drop_and_recreate_on_load_error: bool = False) -> None: + drop_and_recreate_on_load_error: bool = False, + db_conn: Optional[Connection] = None) -> None: self.schema_name = schema_name self.table_name = table_name self.db_driver = db_driver # type: ignore self.db_engine = db_engine + self.db_conn = db_conn self.add_user_perms_for = add_user_perms_for self.add_group_perms_for = add_group_perms_for self.existing_table_handling = existing_table_handling @@ -80,12 +84,16 @@ def move_from_fileobjs_source(self, processing_instructions).move() def can_move_from_fileobjs_source(self) -> bool: - driver = self.db_driver(self.db_engine) + driver = self.db_driver(None, + db_engine=self.db_engine, + db_conn=self.db_conn) loader = driver.loader_from_fileobj() return loader is not None def can_move_directly_from_scheme(self, scheme: str) -> bool: - driver = self.db_driver(self.db_engine) + driver = self.db_driver(None, + db_engine=self.db_engine, + db_conn=self.db_conn) loader = driver.loader() if loader is None: # can't bulk load at all, so can't load direct! @@ -94,7 +102,9 @@ def can_move_directly_from_scheme(self, scheme: str) -> bool: return loader.best_scheme_to_load_from() == scheme def known_supported_records_formats(self) -> List[BaseRecordsFormat]: - driver = self.db_driver(self.db_engine) + driver = self.db_driver(None, + db_engine=self.db_engine, + db_conn=self.db_conn) loader = driver.loader() if loader is None: logger.warning(f"No loader configured for this database type ({self.db_engine.name})") @@ -105,7 +115,9 @@ def can_move_from_format(self, source_records_format: BaseRecordsFormat) -> bool: """Return true if writing the specified format satisfies our format needs""" - driver = self.db_driver(self.db_engine) + driver = self.db_driver(None, + db_engine=self.db_engine, + db_conn=self.db_conn) loader = driver.loader() if loader is None: logger.warning(f"No loader configured for this database type ({self.db_engine.name})") @@ -113,7 +125,9 @@ def can_move_from_format(self, return loader.can_load_this_format(source_records_format) def can_move_from_temp_loc_after_filling_it(self) -> bool: - driver = self.db_driver(self.db_engine) + driver = self.db_driver(None, + db_engine=self.db_engine, + db_conn=self.db_conn) loader = driver.loader() if loader is None: logger.warning(f"No loader configured for this database type ({self.db_engine.name})") @@ -130,7 +144,9 @@ def can_move_from_temp_loc_after_filling_it(self) -> bool: return has_scratch_location def temporary_loadable_directory_scheme(self) -> str: - driver = self.db_driver(self.db_engine) + driver = self.db_driver(None, + db_engine=self.db_engine, + db_conn=self.db_conn) loader = driver.loader() if loader is None: raise TypeError("Please check can_move_from_temp_loc_after_filling_it() " diff --git a/records_mover/session.py b/records_mover/session.py index 00447153e..f54b8a303 100644 --- a/records_mover/session.py +++ b/records_mover/session.py @@ -16,7 +16,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from .db import DBDriver # noqa - from sqlalchemy.engine import Engine # noqa + from sqlalchemy.engine import Engine, Connection # noqa import boto3 # noqa import google.cloud.storage # noqa @@ -261,7 +261,10 @@ def get_db_engine(self, db_facts = creds_provider.db_facts(db_creds_name) return engine_from_db_facts(db_facts) - def db_driver(self, db: 'Engine') -> 'DBDriver': + def 
db_driver(self, + db: Optional[Union['Engine', 'Connection']], + db_conn: Optional['Connection'] = None, + db_engine: Optional['Engine'] = None) -> 'DBDriver': from .db.factory import db_driver kwargs = {} @@ -282,6 +285,8 @@ def db_driver(self, db: 'Engine') -> 'DBDriver': logger.debug('google.cloud.storage not installed', exc_info=True) return db_driver(db=db, + db_conn=db_conn, + db_engine=db_engine, url_resolver=self.url_resolver, **kwargs) diff --git a/records_mover/utils/json_schema.py b/records_mover/utils/json_schema.py index 48c0e0b91..94a3c3dbe 100644 --- a/records_mover/utils/json_schema.py +++ b/records_mover/utils/json_schema.py @@ -52,7 +52,8 @@ def is_optional_type(tp: PythonType) -> bool: def is_impractical_type(python_type: PythonType) -> bool: # can't write code and pass it around in JSON! return (is_callable_type(python_type) or - type(python_type) == enum.EnumMeta) + type(python_type) == enum.EnumMeta + or type(python_type) == typing.ForwardRef) def parse_python_parameter_type(name: str, diff --git a/records_mover/version.py b/records_mover/version.py index 0bb84ff29..c24ed73be 100644 --- a/records_mover/version.py +++ b/records_mover/version.py @@ -1 +1 @@ -__version__ = '1.5.3' +__version__ = '1.5.4' diff --git a/setup.py b/setup.py index eacc54fda..4eaa4bded 100755 --- a/setup.py +++ b/setup.py @@ -130,6 +130,7 @@ def initialize_options(self) -> None: 'jsonschema', # needed for directory_validator.py 'pytz', 'wheel', # needed to support legacy 'setup.py install' + 'parameterized', ] + ( pytest_dependencies + # needed for records_database_fixture retrying drop/creates on @@ -196,7 +197,7 @@ def initialize_options(self) -> None: redshift_dependencies_base = [ # sqlalchemy-redshift 0.7.7 introduced support for Parquet in UNLOAD - 'sqlalchemy-redshift>=0.7.7,<0.8.13', + 'sqlalchemy-redshift>=0.7.7', ] + aws_dependencies + db_dependencies redshift_dependencies_binary = [ @@ -301,7 +302,9 @@ def initialize_options(self) -> None: # what we support # # https://github.com/aws/aws-cli/blob/develop/setup.py - 'PyYAML>=3.10,<5.5', + # #### NEW UNFORTUNATE CONSTRAINTS ##### + # https://github.com/yaml/pyyaml/issues/724 + 'PyYAML>=3.10,<=5.3.1', # Not sure how/if interface will change in db-facts, so # let's be conservative about what we're specifying for now. 
'db-facts>=4,<5', diff --git a/tests/component/db/mysql/test_mysql_db_driver.py b/tests/component/db/mysql/test_mysql_db_driver.py index 3863ff493..aaa5a7679 100644 --- a/tests/component/db/mysql/test_mysql_db_driver.py +++ b/tests/component/db/mysql/test_mysql_db_driver.py @@ -9,7 +9,7 @@ def setUp(self): self.mock_db_engine = MagicMock(name='db_engine') self.mock_url_resolver = Mock(name='url_resolver') self.mock_db_engine.engine = self.mock_db_engine - self.mysql_db_driver = MySQLDBDriver(db=self.mock_db_engine, + self.mysql_db_driver = MySQLDBDriver(db_engine=self.mock_db_engine, url_resolver=self.mock_url_resolver) def test_integer_limits(self): diff --git a/tests/component/db/mysql/test_mysql_load_options_known.py b/tests/component/db/mysql/test_mysql_load_options_known.py index 21a36d64e..3ab0264bd 100644 --- a/tests/component/db/mysql/test_mysql_load_options_known.py +++ b/tests/component/db/mysql/test_mysql_load_options_known.py @@ -9,7 +9,8 @@ class TestMySQLLoadOptionsKnown(unittest.TestCase): def test_load_known_formats(self): mock_db = Mock(name='db') mock_url_resolver = Mock(name='url_resolver') - loader = MySQLLoader(db=mock_db, + loader = MySQLLoader(None, + db_conn=mock_db, url_resolver=mock_url_resolver) known_load_formats = loader.known_supported_records_formats_for_load() for records_format in known_load_formats: diff --git a/tests/component/db/postgres/test_postgres_copy_from_options_load_known.py b/tests/component/db/postgres/test_postgres_copy_from_options_load_known.py index 80bd2d259..7b22bd153 100644 --- a/tests/component/db/postgres/test_postgres_copy_from_options_load_known.py +++ b/tests/component/db/postgres/test_postgres_copy_from_options_load_known.py @@ -13,7 +13,8 @@ def test_load_known_formats(self): mock_db = Mock(name='db') loader = PostgresLoader(url_resolver=mock_url_resolver, meta=mock_meta, - db=mock_db) + db=None, + db_conn=mock_db) known_load_formats = loader.known_supported_records_formats_for_load() for records_format in known_load_formats: unhandled_hints = set(records_format.hints) diff --git a/tests/component/db/postgres/test_postgres_copy_from_options_unload_known.py b/tests/component/db/postgres/test_postgres_copy_from_options_unload_known.py index 44eace1e8..f4b451345 100644 --- a/tests/component/db/postgres/test_postgres_copy_from_options_unload_known.py +++ b/tests/component/db/postgres/test_postgres_copy_from_options_unload_known.py @@ -7,7 +7,7 @@ class TestPostgresCopyOptionsUnloadKnown(unittest.TestCase): def test_unload_known_formats(self): mock_db = Mock(name='db') - loader = PostgresUnloader(db=mock_db) + loader = PostgresUnloader(None, db_conn=mock_db) known_unload_formats = loader.known_supported_records_formats_for_unload() for records_format in known_unload_formats: unhandled_hints = set(records_format.hints) diff --git a/tests/component/db/postgres/test_postgres_db_driver.py b/tests/component/db/postgres/test_postgres_db_driver.py index c63996406..78db2e28a 100644 --- a/tests/component/db/postgres/test_postgres_db_driver.py +++ b/tests/component/db/postgres/test_postgres_db_driver.py @@ -9,7 +9,8 @@ def setUp(self): self.mock_db_engine = MagicMock(name='db_engine') self.mock_url_resolver = Mock(name='url_resolver') self.mock_db_engine.engine = self.mock_db_engine - self.postgres_db_driver = PostgresDBDriver(db=self.mock_db_engine, + self.postgres_db_driver = PostgresDBDriver(None, + db_engine=self.mock_db_engine, url_resolver=self.mock_url_resolver) def test_integer_limits(self): diff --git 
a/tests/integration/records/mover_test_case.py b/tests/integration/records/mover_test_case.py index 65ad5d270..1ae9beba9 100644 --- a/tests/integration/records/mover_test_case.py +++ b/tests/integration/records/mover_test_case.py @@ -9,7 +9,7 @@ def __init__(self, source_db_engine: Optional[Engine] = None, file_variant: Optional[DelimitedVariant] = None) -> None: """ - :param db_engine: Target database of the records move as a SQLAlchemy Engine object. + :param target_db_engine: Target database of the records move as a SQLAlchemy Engine object. :param source_data_db_engine: Source database of the records move. None if we are loading from a file or a dataframe diff --git a/tests/integration/records/multi_db/test_dataframe_schema_sql_creation.py b/tests/integration/records/multi_db/test_dataframe_schema_sql_creation.py index 8efe7a458..8ce7bd9a9 100644 --- a/tests/integration/records/multi_db/test_dataframe_schema_sql_creation.py +++ b/tests/integration/records/multi_db/test_dataframe_schema_sql_creation.py @@ -16,7 +16,7 @@ def test_dataframe_to_int64_and_back_to_object_produces_int_columns(self) -> Non # # https://github.com/bluelabsio/records-mover/pull/103 session = Session() - engine = session.get_db_engine('demo-itest') + db_engine = session.get_db_engine('demo-itest') data = {'Population': [11190846, 1303171035, 207847528]} df = DataFrame(data, columns=['Population']) @@ -26,7 +26,7 @@ def test_dataframe_to_int64_and_back_to_object_produces_int_columns(self) -> Non source = DataframesRecordsSource(dfs=[df]) processing_instructions = ProcessingInstructions() schema = source.initial_records_schema(processing_instructions) - driver = RedshiftDBDriver(db=engine) + driver = RedshiftDBDriver(None, db_engine=db_engine) schema_sql = schema.to_schema_sql(driver=driver, schema_name='my_schema_name', table_name='my_table_name') diff --git a/tests/integration/records/multi_db/test_records_table2table.py b/tests/integration/records/multi_db/test_records_table2table.py index ed4e1574a..773d597ec 100644 --- a/tests/integration/records/multi_db/test_records_table2table.py +++ b/tests/integration/records/multi_db/test_records_table2table.py @@ -1,5 +1,7 @@ +# flake8: noqa + from records_mover.db.quoting import quote_schema_and_table -from records_mover import Session, set_stream_logging +from records_mover import Session from records_mover.records import ExistingTableHandling import logging import time @@ -8,6 +10,7 @@ from ..records_database_fixture import RecordsDatabaseFixture from ..table_validator import RecordsTableValidator from ..purge_old_test_tables import purge_old_tables +from parameterized import parameterized # type: ignore[import] logger = logging.getLogger(__name__) @@ -43,6 +46,14 @@ def schema_name(db_name): raise NotImplementedError('Teach me how to determine a schema name for ' + db_name) +def name_func(testcase_func, param_num, param): + source, target = param[0] + return f"test_move_and_verify_from_{source}_to_{target}" + + +SOURCE_TARGET_PAIRS = [(source, target) for source in DB_TYPES for target in DB_TYPES] + + class RecordsMoverTable2TableIntegrationTest(unittest.TestCase): # # actual test methods are defined dynamically after the class @@ -64,6 +75,8 @@ def move_and_verify(self, source_dbname: str, target_dbname: str) -> None: sources = records.sources source_engine = session.get_db_engine(source_dbname) target_engine = session.get_db_engine(target_dbname) + source_conn = source_engine.connect() + target_conn = target_engine.connect() source_schema_name = 
schema_name(source_dbname) target_schema_name = schema_name(target_dbname) source_table_name = f'itest_source_{BUILD_NUM}_{CURRENT_EPOCH}' @@ -76,11 +89,13 @@ def move_and_verify(self, source_dbname: str, target_dbname: str) -> None: existing = ExistingTableHandling.DROP_AND_RECREATE source = sources.table(schema_name=source_schema_name, table_name=source_table_name, - db_engine=source_engine) + db_engine=source_engine, + db_conn=source_conn) target = targets.table(schema_name=target_schema_name, table_name=TARGET_TABLE_NAME, db_engine=target_engine, - existing_table_handling=existing) + existing_table_handling=existing, + db_conn=target_conn) out = records.move(source, target) # redshift doesn't give reliable info on load results, so this # will be None or 1 @@ -90,36 +105,18 @@ def move_and_verify(self, source_dbname: str, target_dbname: str) -> None: validator.validate(schema_name=target_schema_name, table_name=TARGET_TABLE_NAME) - quoted_target = quote_schema_and_table(target_engine, target_schema_name, TARGET_TABLE_NAME) + quoted_target = quote_schema_and_table(None, target_schema_name, TARGET_TABLE_NAME, + db_engine=target_engine) sql = f"DROP TABLE {quoted_target}" - with target_engine.connect() as connection: - with connection.begin(): - connection.exec_driver_sql(sql) - - records_database_fixture.tear_down() + target_conn.exec_driver_sql(sql) # type: ignore[attr-defined] + # records_database_fixture.tear_down() + source_conn.close() + target_conn.close() -def create_test_func(source_name, target_name): - def source2target(self): + @parameterized.expand(SOURCE_TARGET_PAIRS, name_func=name_func) + def test_move_and_verify(self, source, target): + source_name = DB_NAMES[source] + target_name = DB_NAMES[target] + print(f"Moving from {source_name} to {target_name}") self.move_and_verify(source_name, target_name) - return source2target - - -if __name__ == '__main__': - set_stream_logging(level=logging.DEBUG) - logging.getLogger('botocore').setLevel(logging.INFO) - logging.getLogger('boto3').setLevel(logging.INFO) - logging.getLogger('urllib3').setLevel(logging.INFO) - - for source in DB_TYPES: - for target in DB_TYPES: - source_name = DB_NAMES[source] - target_name = DB_NAMES[target] - f = create_test_func(source_name, target_name) - func_name = f"test_{source}2{target}" - - setattr(RecordsMoverTable2TableIntegrationTest, - func_name, - f) - - unittest.main() diff --git a/tests/integration/records/purge_old_test_tables.py b/tests/integration/records/purge_old_test_tables.py index 7c2a0845c..cc7caa759 100755 --- a/tests/integration/records/purge_old_test_tables.py +++ b/tests/integration/records/purge_old_test_tables.py @@ -35,7 +35,9 @@ def purge_old_tables(schema_name: str, table_name_prefix: str, ] print(f"Tables to purge matching {schema_name}.{table_name_prefix}_: {purgable_table_names}") for table_name in purgable_table_names: - sql = f"DROP TABLE {quote_schema_and_table(db_engine, schema_name, table_name)}" + sql = ( + "DROP TABLE " + f"{quote_schema_and_table(None, schema_name, table_name, db_engine=db_engine)}") print(sql) with db_engine.connect() as connection: with connection.begin(): diff --git a/tests/integration/records/records_database_fixture.py b/tests/integration/records/records_database_fixture.py index 6e9638e68..685d3e0bd 100644 --- a/tests/integration/records/records_database_fixture.py +++ b/tests/integration/records/records_database_fixture.py @@ -8,7 +8,10 @@ class RecordsDatabaseFixture: def quote_schema_and_table(self, schema, table): - return 
quote_schema_and_table(self.engine, schema, table) + return quote_schema_and_table(None, + schema=schema, + table=table, + db_engine=self.engine) def __init__(self, db_engine, schema_name, table_name): self.engine = db_engine diff --git a/tests/integration/records/records_datetime_fixture.py b/tests/integration/records/records_datetime_fixture.py index 2597f2534..668ffd421 100644 --- a/tests/integration/records/records_datetime_fixture.py +++ b/tests/integration/records/records_datetime_fixture.py @@ -1,3 +1,5 @@ +# flake8: noqa + from records_mover.db.quoting import quote_schema_and_table from records_mover.utils.retry import bigquery_retry from .datetime_cases import ( @@ -5,27 +7,34 @@ SAMPLE_HOUR, SAMPLE_MINUTE, SAMPLE_SECOND, SAMPLE_OFFSET, SAMPLE_LONG_TZ ) from sqlalchemy import text -from sqlalchemy.engine import Engine +from sqlalchemy.engine import Engine, Connection +from typing import Optional + import logging logger = logging.getLogger(__name__) class RecordsDatetimeFixture: - def __init__(self, engine: Engine, schema_name: str, table_name: str): + def __init__(self, engine: Engine, schema_name: str, table_name: str, + connection: Optional[Connection] = None): self.engine = engine self.schema_name = schema_name self.table_name = table_name + self.connection = connection def quote_schema_and_table(self, schema, table): - return quote_schema_and_table(self.engine, schema, table) + return quote_schema_and_table(None, schema, table, db_engine=self.engine) @bigquery_retry() def drop_table_if_exists(self, schema, table): sql = f"DROP TABLE IF EXISTS {self.quote_schema_and_table(schema, table)}" - with self.engine.connect() as connection: - with connection.begin(): - connection.execute(text(sql)) + if not self.connection: + with self.engine.connect() as connection: + with connection.begin(): + connection.execute(text(sql)) + else: + self.connection.execute(text(sql)) def createDateTimeTzTable(self) -> None: if self.engine.name == 'redshift': @@ -87,9 +96,13 @@ def createDateTimeTable(self) -> None: """ # noqa else: raise NotImplementedError(f"Please teach me how to integration test {self.engine.name}") - with self.engine.connect() as connection: - with connection.begin(): - connection.exec_driver_sql(create_tables) + if not self.connection: + with self.engine.connect() as connection: + with connection.begin(): + connection.exec_driver_sql(create_tables) + else: + with self.connection.begin(): + self.connection.exec_driver_sql(create_tables) # type: ignore[attr-defined] @bigquery_retry() def createDateTable(self) -> None: @@ -120,9 +133,13 @@ def createDateTable(self) -> None: """ # noqa else: raise NotImplementedError(f"Please teach me how to integration test {self.engine.name}") - with self.engine.connect() as connection: - with connection.begin(): - connection.exec_driver_sql(create_tables) + if not self.connection: + with self.engine.connect() as connection: + with connection.begin(): + connection.exec_driver_sql(create_tables) + else: + with self.connection.begin(): + self.connection.exec_driver_sql(create_tables) # type: ignore[attr-defined] @bigquery_retry() def createTimeTable(self): @@ -153,9 +170,13 @@ def createTimeTable(self): """ # noqa else: raise NotImplementedError(f"Please teach me how to integration test {self.engine.name}") - with self.engine.connect() as connection: - with connection.begin(): - connection.exec_driver_sql(create_tables) + if not self.connection: + with self.engine.connect() as connection: + with connection.begin(): + 
connection.exec_driver_sql(create_tables) + else: + with self.connection.begin(): + self.connection.exec_driver_sql(create_tables) def drop_tables(self): logger.info('Dropping tables...') diff --git a/tests/integration/records/records_numeric_database_fixture.py b/tests/integration/records/records_numeric_database_fixture.py index 68d7c2f96..66c381ba9 100644 --- a/tests/integration/records/records_numeric_database_fixture.py +++ b/tests/integration/records/records_numeric_database_fixture.py @@ -132,7 +132,8 @@ def bring_up(self): connection.exec_driver_sql(statement) def quote_schema_and_table(self, schema, table): - return quote_schema_and_table(self.engine, schema, table) + return quote_schema_and_table(None, schema, table, + db_engine=self.engine) def drop_table_if_exists(self, schema, table): sql = f"DROP TABLE IF EXISTS {self.quote_schema_and_table(schema, table)}" diff --git a/tests/integration/records/single_db/base_records_test.py b/tests/integration/records/single_db/base_records_test.py index b9ecccaa9..af0b6fa4a 100644 --- a/tests/integration/records/single_db/base_records_test.py +++ b/tests/integration/records/single_db/base_records_test.py @@ -46,7 +46,8 @@ def setUp(self): default_db_creds_name=None, default_aws_creds_name=None) self.engine = self.session.get_default_db_engine() - self.driver = self.session.db_driver(self.engine) + self.connection = self.engine.connect() + self.driver = self.session.db_driver(None, db_conn=self.connection, db_engine=self.engine) if self.engine.name == 'bigquery': self.schema_name = 'bq_itest' # avoid per-table rate limits @@ -72,6 +73,8 @@ def setUp(self): def tearDown(self): self.session = None self.fixture.tear_down() + if not self.connection.closed: + self.connection.close() def table(self, schema, table): return Table(table, self.meta, schema=schema, autoload_with=self.engine) @@ -108,7 +111,8 @@ def unload_column_to_string(self, with tempfile.TemporaryDirectory() as directory_name: source = sources.table(schema_name=self.schema_name, table_name=self.table_name, - db_engine=self.engine) + db_engine=self.engine, + db_conn=self.connection) directory_url = pathlib.Path(directory_name).as_uri() + '/' target = targets.directory_from_url(output_url=directory_url, records_format=records_format) diff --git a/tests/integration/records/single_db/test_records_numeric.py b/tests/integration/records/single_db/test_records_numeric.py index 63b9656e2..c91a055cc 100644 --- a/tests/integration/records/single_db/test_records_numeric.py +++ b/tests/integration/records/single_db/test_records_numeric.py @@ -52,6 +52,9 @@ def setUp(self): table_name=self.table_name) self.numeric_fixture.tear_down() + def tearDown(self): + ... 
+ def test_numeric_schema_fields_created(self) -> None: self.numeric_fixture.bring_up() with tempfile.TemporaryDirectory(prefix='test_records_numeric_schema') as tempdir: diff --git a/tests/integration/records/single_db/test_records_unload.py b/tests/integration/records/single_db/test_records_unload.py index cf833037b..b9d8c42ed 100644 --- a/tests/integration/records/single_db/test_records_unload.py +++ b/tests/integration/records/single_db/test_records_unload.py @@ -10,6 +10,9 @@ class RecordsUnloadIntegrationTest(BaseRecordsIntegrationTest): + def tearDown(self): + self.connection.close() + def test_unload_csv_format(self): self.unload_and_verify('delimited', 'csv') @@ -74,7 +77,8 @@ def unload(self, variant, directory, hints={}) -> None: sources = self.records.sources source = sources.table(schema_name=self.schema_name, table_name=self.table_name, - db_engine=self.engine) + db_engine=self.engine, + db_conn=self.connection) target = targets.directory_from_url(output_url=directory_url, records_format=records_format) out = self.records.move(source, target) diff --git a/tests/integration/records/single_db/test_records_unload_datetime.py b/tests/integration/records/single_db/test_records_unload_datetime.py index 2242ef90a..c12424db0 100644 --- a/tests/integration/records/single_db/test_records_unload_datetime.py +++ b/tests/integration/records/single_db/test_records_unload_datetime.py @@ -24,11 +24,12 @@ def setUp(self) -> None: super().setUp() self.datetime_fixture = RecordsDatetimeFixture(engine=self.engine, table_name=self.table_name, - schema_name=self.schema_name) + schema_name=self.schema_name, + connection=self.connection) def tearDown(self): - super().tearDown() self.datetime_fixture.drop_tables() + self.connection.close() def test_unload_date(self) -> None: self.datetime_fixture.createDateTable() @@ -73,7 +74,7 @@ def test_unload_date(self) -> None: 'dateformat': dateformat, 'compression': None, 'header-row': False, - **addl_hints, # type: ignore + **addl_hints, }) expect_pandas_failure = (not self.has_pandas()) and uses_pandas try: @@ -141,7 +142,7 @@ def test_unload_datetime(self) -> None: 'datetimeformat': datetimeformat, 'compression': None, 'header-row': False, - **addl_hints, # type: ignore + **addl_hints, }) expect_pandas_failure = (not self.has_pandas()) and uses_pandas try: @@ -215,7 +216,7 @@ def test_unload_datetimetz(self) -> None: 'datetimeformattz': datetimeformattz, 'compression': None, 'header-row': False, - **addl_hints, # type: ignore + **addl_hints, }) expect_pandas_failure = (not self.has_pandas()) and uses_pandas try: @@ -284,7 +285,7 @@ def test_unload_timeonly(self) -> None: 'timeonlyformat': timeonlyformat, 'compression': None, 'header-row': False, - **addl_hints, # type: ignore + **addl_hints, }) expect_pandas_failure = (not self.has_pandas()) and uses_pandas try: diff --git a/tests/integration/records/table_validator.py b/tests/integration/records/table_validator.py index bc1a19304..2b730466c 100644 --- a/tests/integration/records/table_validator.py +++ b/tests/integration/records/table_validator.py @@ -204,7 +204,7 @@ def validate_data_values(self, 'YYYY-MM-DD HH24:MI:SS.US TZ') as timestamptzstr FROM {schema_name}.{table_name} """) - out = connection.execute(select_sql, **params) + out = connection.execute(select_sql, params) ret_all = out.fetchall() assert 1 == len(ret_all) ret = ret_all[0] diff --git a/tests/unit/db/bigquery/test_bigquery_db_driver.py b/tests/unit/db/bigquery/test_bigquery_db_driver.py index 86acaf653..617830b5b 100644 --- 
a/tests/unit/db/bigquery/test_bigquery_db_driver.py +++ b/tests/unit/db/bigquery/test_bigquery_db_driver.py @@ -14,7 +14,8 @@ def setUp(self, mock_BigQueryLoader): self.mock_db_engine = MagicMock(name='db_engine') self.mock_url_resolver = Mock(name='url_resolver') self.mock_BigQueryLoader = mock_BigQueryLoader - self.bigquery_db_driver = BigQueryDBDriver(db=self.mock_db_engine, + self.bigquery_db_driver = BigQueryDBDriver(db=None, + db_engine=self.mock_db_engine, url_resolver=self.mock_url_resolver) def test_load_implemented(self): diff --git a/tests/unit/db/bigquery/test_bigquery_loader.py b/tests/unit/db/bigquery/test_bigquery_loader.py index c38a765bc..e0de2dea3 100644 --- a/tests/unit/db/bigquery/test_bigquery_loader.py +++ b/tests/unit/db/bigquery/test_bigquery_loader.py @@ -15,8 +15,8 @@ class TestBigQueryLoader(unittest.TestCase): def test_load_with_bad_schema_name(self, mock_load_job_config): mock_db = Mock(name='mock_db') mock_url_resolver = MagicMock(name='mock_url_resolver') - big_query_loader = BigQueryLoader(db=mock_db, url_resolver=mock_url_resolver, - gcs_temp_base_loc=None) + big_query_loader = BigQueryLoader(db=None, url_resolver=mock_url_resolver, + gcs_temp_base_loc=None, db_conn=mock_db) mock_schema = 'my_project.my_dataset.something_invalid' mock_table = Mock(name='mock_table') mock_load_plan = Mock(name='mock_load_plan') @@ -38,8 +38,8 @@ def test_load_with_bad_schema_name(self, mock_load_job_config): def test_load_with_default_project(self, mock_load_job_config): mock_db = Mock(name='mock_db') mock_url_resolver = MagicMock(name='mock_url_resolver') - big_query_loader = BigQueryLoader(db=mock_db, url_resolver=mock_url_resolver, - gcs_temp_base_loc=None) + big_query_loader = BigQueryLoader(db=None, url_resolver=mock_url_resolver, + gcs_temp_base_loc=None, db_conn=mock_db) mock_schema = 'my_dataset' mock_table = 'my_table' mock_load_plan = Mock(name='mock_load_plan') @@ -75,8 +75,8 @@ def test_load(self, mock_load_job_config): mock_db = Mock(name='mock_db') mock_url_resolver = MagicMock(name='mock_url_resolver') mock_gcs_temp_base_loc = None - big_query_loader = BigQueryLoader(db=mock_db, url_resolver=mock_url_resolver, - gcs_temp_base_loc=mock_gcs_temp_base_loc) + big_query_loader = BigQueryLoader(db=None, url_resolver=mock_url_resolver, + gcs_temp_base_loc=mock_gcs_temp_base_loc, db_conn=mock_db) mock_schema = 'my_project.my_dataset' mock_table = 'mytable' mock_load_plan = Mock(name='mock_load_plan') @@ -111,8 +111,8 @@ def test_load_with_job_failure(self, mock_load_job_config): mock_db = Mock(name='mock_db') mock_url_resolver = MagicMock(name='mock_url_resolver') mock_gcs_temp_base_loc = None - big_query_loader = BigQueryLoader(db=mock_db, url_resolver=mock_url_resolver, - gcs_temp_base_loc=mock_gcs_temp_base_loc) + big_query_loader = BigQueryLoader(db=None, url_resolver=mock_url_resolver, + gcs_temp_base_loc=mock_gcs_temp_base_loc, db_conn=mock_db) mock_schema = 'my_project.my_dataset' mock_table = 'mytable' mock_load_plan = Mock(name='mock_load_plan') @@ -142,8 +142,8 @@ def test_load_no_table(self, mock_load_job_config): mock_db = Mock(name='mock_db') mock_url_resolver = MagicMock(name='mock_url_resolver') mock_gcs_temp_base_loc = None - big_query_loader = BigQueryLoader(db=mock_db, url_resolver=mock_url_resolver, - gcs_temp_base_loc=mock_gcs_temp_base_loc) + big_query_loader = BigQueryLoader(db=None, url_resolver=mock_url_resolver, + gcs_temp_base_loc=mock_gcs_temp_base_loc, db_conn=mock_db) mock_schema = 'my_project.my_dataset' mock_table = 'mytable' 
mock_load_plan = Mock(name='mock_load_plan') @@ -169,8 +169,8 @@ def test_load_from_fileobj_true(self, mock_load_job_config): mock_db = Mock(name='mock_db') mock_url_resolver = MagicMock(name='mock_url_resolver') mock_gcs_temp_base_loc = None - big_query_loader = BigQueryLoader(db=mock_db, url_resolver=mock_url_resolver, - gcs_temp_base_loc=mock_gcs_temp_base_loc) + big_query_loader = BigQueryLoader(db=None, url_resolver=mock_url_resolver, + gcs_temp_base_loc=mock_gcs_temp_base_loc, db_conn=mock_db) mock_schema = 'my_project.my_dataset' mock_table = 'mytable' mock_load_plan = Mock(name='mock_load_plan') @@ -205,8 +205,8 @@ def test_load_from_fileobj_error(self, mock_load_job_config): mock_db = Mock(name='mock_db') mock_url_resolver = MagicMock(name='mock_url_resolver') mock_gcs_temp_base_loc = None - big_query_loader = BigQueryLoader(db=mock_db, url_resolver=mock_url_resolver, - gcs_temp_base_loc=mock_gcs_temp_base_loc) + big_query_loader = BigQueryLoader(db=None, url_resolver=mock_url_resolver, + gcs_temp_base_loc=mock_gcs_temp_base_loc, db_conn=mock_db) mock_schema = 'my_project.my_dataset' mock_table = 'mytable' mock_load_plan = Mock(name='mock_load_plan') @@ -241,8 +241,8 @@ def test_load_with_fileobj_fallback(self, mock_load_job_config): mock_db = Mock(name='mock_db') mock_url_resolver = MagicMock(name='mock_url_resolver') mock_gcs_temp_base_loc = None - big_query_loader = BigQueryLoader(db=mock_db, url_resolver=mock_url_resolver, - gcs_temp_base_loc=mock_gcs_temp_base_loc) + big_query_loader = BigQueryLoader(db=None, url_resolver=mock_url_resolver, + gcs_temp_base_loc=mock_gcs_temp_base_loc, db_conn=mock_db) mock_schema = 'my_project.my_dataset' mock_table = 'mytable' mock_load_plan = Mock(name='mock_load_plan') @@ -282,8 +282,8 @@ def test_load_with_fileobj_fallback(self, mock_load_job_config): def test_known_supported_records_formats_for_load(self): mock_db = Mock(name='db') mock_url_resolver = Mock(name='url_resolver') - bigquery_loader = BigQueryLoader(db=mock_db, url_resolver=mock_url_resolver, - gcs_temp_base_loc=None) + bigquery_loader = BigQueryLoader(db=None, url_resolver=mock_url_resolver, + gcs_temp_base_loc=None, db_conn=mock_db) out = bigquery_loader.known_supported_records_formats_for_load() self.assertEqual(3, len(out)) delimited_records_format = out[0] @@ -297,8 +297,8 @@ def test_known_supported_records_formats_for_load(self): def test_temporary_gcs_directory_loc_none(self): mock_db = Mock(name='db') mock_url_resolver = Mock(name='url_resolver') - bigquery_loader = BigQueryLoader(db=mock_db, url_resolver=mock_url_resolver, - gcs_temp_base_loc=None) + bigquery_loader = BigQueryLoader(db=None, url_resolver=mock_url_resolver, + gcs_temp_base_loc=None, db_conn=mock_db) with self.assertRaises(NoTemporaryBucketConfiguration): with bigquery_loader.temporary_gcs_directory_loc(): pass @@ -307,8 +307,8 @@ def test_temporary_loadable_directory_loc(self): mock_db = Mock(name='db') mock_url_resolver = Mock(name='url_resolver') mock_gcs_temp_base_loc = MagicMock(name='gcs_temp_base_loc') - bigquery_loader = BigQueryLoader(db=mock_db, url_resolver=mock_url_resolver, - gcs_temp_base_loc=mock_gcs_temp_base_loc) + bigquery_loader = BigQueryLoader(db=None, url_resolver=mock_url_resolver, + gcs_temp_base_loc=mock_gcs_temp_base_loc, db_conn=mock_db) with bigquery_loader.temporary_loadable_directory_loc() as loc: self.assertEqual(loc, mock_gcs_temp_base_loc.temporary_directory.return_value.__enter__. 
@@ -318,8 +318,8 @@ def test_temporary_gcs_directory_loc(self): mock_db = Mock(name='db') mock_url_resolver = Mock(name='url_resolver') mock_gcs_temp_base_loc = MagicMock(name='gcs_temp_base_loc') - bigquery_loader = BigQueryLoader(db=mock_db, url_resolver=mock_url_resolver, - gcs_temp_base_loc=mock_gcs_temp_base_loc) + bigquery_loader = BigQueryLoader(db=None, url_resolver=mock_url_resolver, + gcs_temp_base_loc=mock_gcs_temp_base_loc, db_conn=mock_db) with bigquery_loader.temporary_gcs_directory_loc() as loc: self.assertEqual(loc, mock_gcs_temp_base_loc.temporary_directory.return_value.__enter__. @@ -329,22 +329,22 @@ def test_has_temporary_loadable_directory_loc_true(self): mock_db = Mock(name='db') mock_url_resolver = Mock(name='url_resolver') mock_gcs_temp_base_loc = MagicMock(name='gcs_temp_base_loc') - bigquery_loader = BigQueryLoader(db=mock_db, url_resolver=mock_url_resolver, - gcs_temp_base_loc=mock_gcs_temp_base_loc) + bigquery_loader = BigQueryLoader(db=None, url_resolver=mock_url_resolver, + gcs_temp_base_loc=mock_gcs_temp_base_loc, db_conn=mock_db) self.assertTrue(bigquery_loader.has_temporary_loadable_directory_loc()) def test_temporary_loadable_directory_scheme(self): mock_db = Mock(name='db') mock_url_resolver = Mock(name='url_resolver') mock_gcs_temp_base_loc = MagicMock(name='gcs_temp_base_loc') - bigquery_loader = BigQueryLoader(db=mock_db, url_resolver=mock_url_resolver, - gcs_temp_base_loc=mock_gcs_temp_base_loc) + bigquery_loader = BigQueryLoader(db=None, url_resolver=mock_url_resolver, + gcs_temp_base_loc=mock_gcs_temp_base_loc, db_conn=mock_db) self.assertEqual('gs', bigquery_loader.temporary_loadable_directory_scheme()) def test_best_scheme_to_load_from(self): mock_db = Mock(name='db') mock_url_resolver = Mock(name='url_resolver') mock_gcs_temp_base_loc = MagicMock(name='gcs_temp_base_loc') - bigquery_loader = BigQueryLoader(db=mock_db, url_resolver=mock_url_resolver, - gcs_temp_base_loc=mock_gcs_temp_base_loc) + bigquery_loader = BigQueryLoader(db=None, url_resolver=mock_url_resolver, + gcs_temp_base_loc=mock_gcs_temp_base_loc, db_conn=mock_db) self.assertEqual('gs', bigquery_loader.best_scheme_to_load_from()) diff --git a/tests/unit/db/bigquery/test_bigquery_loader_can_load_this_format.py b/tests/unit/db/bigquery/test_bigquery_loader_can_load_this_format.py index ced3eda19..a2b6ca5cc 100644 --- a/tests/unit/db/bigquery/test_bigquery_loader_can_load_this_format.py +++ b/tests/unit/db/bigquery/test_bigquery_loader_can_load_this_format.py @@ -29,8 +29,8 @@ def test_can_load_this_format_true(self, mock_load_plan.records_format = mock_source_records_format mock_url_resolver = Mock(name='url_resolver') mock_source_records_format.hints = {} - bigquery_loader = BigQueryLoader(db=mock_db, url_resolver=mock_url_resolver, - gcs_temp_base_loc=None) + bigquery_loader = BigQueryLoader(db=None, url_resolver=mock_url_resolver, + gcs_temp_base_loc=None, db_conn=mock_db) out = bigquery_loader.can_load_this_format(mock_source_records_format) mock_ProcessingInstructions.assert_called_with() mock_RecordsLoadPlan.\ @@ -55,8 +55,8 @@ def test_can_load_this_format_delimited_false(self, mock_url_resolver = Mock(name='url_resolver') mock_load_job_config.side_effect = NotImplementedError mock_source_records_format.hints = {} - bigquery_loader = BigQueryLoader(db=mock_db, url_resolver=mock_url_resolver, - gcs_temp_base_loc=None) + bigquery_loader = BigQueryLoader(db=None, url_resolver=mock_url_resolver, + gcs_temp_base_loc=None, db_conn=mock_db) out = 
bigquery_loader.can_load_this_format(mock_source_records_format) mock_ProcessingInstructions.assert_called_with() mock_RecordsLoadPlan.\ @@ -80,8 +80,8 @@ def test_can_load_this_format_true_avro(self, mock_load_plan.records_format = mock_source_records_format mock_url_resolver = Mock(name='url_resolver') mock_source_records_format.hints = {} - bigquery_loader = BigQueryLoader(db=mock_db, url_resolver=mock_url_resolver, - gcs_temp_base_loc=None) + bigquery_loader = BigQueryLoader(db=None, url_resolver=mock_url_resolver, + gcs_temp_base_loc=None, db_conn=mock_db) out = bigquery_loader.can_load_this_format(mock_source_records_format) mock_ProcessingInstructions.assert_called_with() mock_RecordsLoadPlan.\ @@ -104,8 +104,8 @@ def test_can_load_this_format_false_newformat(self, mock_load_plan.records_format = mock_source_records_format mock_url_resolver = Mock(name='url_resolver') mock_source_records_format.hints = {} - bigquery_loader = BigQueryLoader(db=mock_db, url_resolver=mock_url_resolver, - gcs_temp_base_loc=None) + bigquery_loader = BigQueryLoader(db=None, url_resolver=mock_url_resolver, + gcs_temp_base_loc=None, db_conn=mock_db) out = bigquery_loader.can_load_this_format(mock_source_records_format) mock_ProcessingInstructions.assert_called_with() mock_RecordsLoadPlan.\ @@ -128,8 +128,8 @@ def test_can_load_this_format_true_parquet(self, mock_load_plan.records_format = mock_source_records_format mock_url_resolver = Mock(name='url_resolver') mock_source_records_format.hints = {} - bigquery_loader = BigQueryLoader(db=mock_db, url_resolver=mock_url_resolver, - gcs_temp_base_loc=None) + bigquery_loader = BigQueryLoader(db=None, url_resolver=mock_url_resolver, + gcs_temp_base_loc=None, db_conn=mock_db) out = bigquery_loader.can_load_this_format(mock_source_records_format) mock_ProcessingInstructions.assert_called_with() mock_RecordsLoadPlan.\ diff --git a/tests/unit/db/bigquery/test_bigquery_unloader.py b/tests/unit/db/bigquery/test_bigquery_unloader.py index 5e262727e..6e4ac28a3 100644 --- a/tests/unit/db/bigquery/test_bigquery_unloader.py +++ b/tests/unit/db/bigquery/test_bigquery_unloader.py @@ -14,8 +14,11 @@ def test_can_unload_format_avro_true(self): mock_db = Mock(name='mock_db') mock_url_resolver = MagicMock(name='mock_url_resolver') mock_gcs_temp_base_loc = MagicMock(name='gcs_temp_base_loc') - big_query_unloader = BigQueryUnloader(db=mock_db, url_resolver=mock_url_resolver, - gcs_temp_base_loc=mock_gcs_temp_base_loc) + big_query_unloader = BigQueryUnloader( + db=None, + url_resolver=mock_url_resolver, + gcs_temp_base_loc=mock_gcs_temp_base_loc, + db_conn=mock_db) avro_format = AvroRecordsFormat() self.assertTrue(big_query_unloader.can_unload_format(avro_format)) @@ -23,8 +26,11 @@ def test_can_unload_format_delimited_false(self): mock_db = Mock(name='mock_db') mock_url_resolver = MagicMock(name='mock_url_resolver') mock_gcs_temp_base_loc = MagicMock(name='gcs_temp_base_loc') - big_query_unloader = BigQueryUnloader(db=mock_db, url_resolver=mock_url_resolver, - gcs_temp_base_loc=mock_gcs_temp_base_loc) + big_query_unloader = BigQueryUnloader( + db=None, + url_resolver=mock_url_resolver, + gcs_temp_base_loc=mock_gcs_temp_base_loc, + db_conn=mock_db) delimited_format = DelimitedRecordsFormat() self.assertFalse(big_query_unloader.can_unload_format(delimited_format)) @@ -32,32 +38,44 @@ def test_can_unload_to_scheme_gs_true(self): mock_db = Mock(name='mock_db') mock_url_resolver = MagicMock(name='mock_url_resolver') mock_gcs_temp_base_loc = MagicMock(name='gcs_temp_base_loc') - 
big_query_unloader = BigQueryUnloader(db=mock_db, url_resolver=mock_url_resolver, - gcs_temp_base_loc=mock_gcs_temp_base_loc) + big_query_unloader = BigQueryUnloader( + db=None, + url_resolver=mock_url_resolver, + gcs_temp_base_loc=mock_gcs_temp_base_loc, + db_conn=mock_db) self.assertTrue(big_query_unloader.can_unload_to_scheme('gs')) def test_can_unload_to_scheme_other_with_temp_bucket_true(self): mock_db = Mock(name='mock_db') mock_url_resolver = MagicMock(name='mock_url_resolver') mock_gcs_temp_base_loc = MagicMock(name='gcs_temp_base_loc') - big_query_unloader = BigQueryUnloader(db=mock_db, url_resolver=mock_url_resolver, - gcs_temp_base_loc=mock_gcs_temp_base_loc) + big_query_unloader = BigQueryUnloader( + db=None, + url_resolver=mock_url_resolver, + gcs_temp_base_loc=mock_gcs_temp_base_loc, + db_conn=mock_db) self.assertTrue(big_query_unloader.can_unload_to_scheme('blah')) def test_can_unload_to_scheme_other_with_no_temp_bucket_true(self): mock_db = Mock(name='mock_db') mock_url_resolver = MagicMock(name='mock_url_resolver') mock_gcs_temp_base_loc = None - big_query_unloader = BigQueryUnloader(db=mock_db, url_resolver=mock_url_resolver, - gcs_temp_base_loc=mock_gcs_temp_base_loc) + big_query_unloader = BigQueryUnloader( + db=None, + url_resolver=mock_url_resolver, + gcs_temp_base_loc=mock_gcs_temp_base_loc, + db_conn=mock_db) self.assertFalse(big_query_unloader.can_unload_to_scheme('blah')) def test_known_supported_records_formats_for_unload(self): mock_db = Mock(name='mock_db') mock_url_resolver = MagicMock(name='mock_url_resolver') mock_gcs_temp_base_loc = MagicMock(name='gcs_temp_base_loc') - big_query_unloader = BigQueryUnloader(db=mock_db, url_resolver=mock_url_resolver, - gcs_temp_base_loc=mock_gcs_temp_base_loc) + big_query_unloader = BigQueryUnloader( + db=None, + url_resolver=mock_url_resolver, + gcs_temp_base_loc=mock_gcs_temp_base_loc, + db_conn=mock_db) self.assertEqual([type(format) for format in big_query_unloader.known_supported_records_formats_for_unload()], @@ -67,8 +85,11 @@ def test_temporary_unloadable_directory_loc_raises(self): mock_db = Mock(name='mock_db') mock_url_resolver = MagicMock(name='mock_url_resolver') mock_gcs_temp_base_loc = None - big_query_unloader = BigQueryUnloader(db=mock_db, url_resolver=mock_url_resolver, - gcs_temp_base_loc=mock_gcs_temp_base_loc) + big_query_unloader = BigQueryUnloader( + db=None, + url_resolver=mock_url_resolver, + gcs_temp_base_loc=mock_gcs_temp_base_loc, + db_conn=mock_db) with self.assertRaises(NoTemporaryBucketConfiguration): with big_query_unloader.temporary_unloadable_directory_loc(): pass @@ -77,8 +98,11 @@ def test_temporary_unloadable_directory_loc(self): mock_db = Mock(name='mock_db') mock_url_resolver = MagicMock(name='mock_url_resolver') mock_gcs_temp_base_loc = MagicMock(name='gcs_temp_base_loc') - big_query_unloader = BigQueryUnloader(db=mock_db, url_resolver=mock_url_resolver, - gcs_temp_base_loc=mock_gcs_temp_base_loc) + big_query_unloader = BigQueryUnloader( + db=None, + url_resolver=mock_url_resolver, + gcs_temp_base_loc=mock_gcs_temp_base_loc, + db_conn=mock_db) with big_query_unloader.temporary_unloadable_directory_loc() as temp_loc: self.assertEqual(temp_loc, mock_gcs_temp_base_loc.temporary_directory.return_value.__enter__. 
@@ -88,8 +112,11 @@ def test_unload(self): mock_db = Mock(name='mock_db') mock_url_resolver = MagicMock(name='mock_url_resolver') mock_gcs_temp_base_loc = MagicMock(name='gcs_temp_base_loc') - big_query_unloader = BigQueryUnloader(db=mock_db, url_resolver=mock_url_resolver, - gcs_temp_base_loc=mock_gcs_temp_base_loc) + big_query_unloader = BigQueryUnloader( + db=None, + url_resolver=mock_url_resolver, + gcs_temp_base_loc=mock_gcs_temp_base_loc, + db_conn=mock_db) mock_schema = 'myproject.mydataset' mock_table = 'mytable' mock_unload_plan = Mock(name='unload_plan') diff --git a/tests/unit/db/mysql/test_loader.py b/tests/unit/db/mysql/test_loader.py index 883e5c7ef..77cb6817a 100644 --- a/tests/unit/db/mysql/test_loader.py +++ b/tests/unit/db/mysql/test_loader.py @@ -10,8 +10,9 @@ def setUp(self): self.mock_url_resolver = Mock(name='url_resolver') self.mock_db_engine = MagicMock(name='db_engine') self.mock_db_engine.engine = self.mock_db_engine - self.loader = MySQLLoader(db=self.mock_db_engine, - url_resolver=self.mock_url_resolver) + self.loader = MySQLLoader(db=None, + url_resolver=self.mock_url_resolver, + db_engine=self.mock_db_engine) @patch('records_mover.db.mysql.loader.complain_on_unhandled_hints') @patch('records_mover.db.mysql.loader.mysql_load_options') @@ -67,5 +68,5 @@ def test_load_happy_path(self, mock_table, mock_load_plan, mock_directory) - self.mock_db_engine.execute.assert_called_with(mock_sql) + self.mock_db_engine.connect.return_value.execute.assert_called_with(mock_sql) self.assertEqual(out, None) diff --git a/tests/unit/db/postgres/test_loader.py b/tests/unit/db/postgres/test_loader.py index 2d142915d..172321a96 100644 --- a/tests/unit/db/postgres/test_loader.py +++ b/tests/unit/db/postgres/test_loader.py @@ -12,7 +12,8 @@ def setUp(self): self.mock_db = MagicMock(name='db') self.loader = PostgresLoader(self.mock_url_resolver, self.mock_meta, - self.mock_db) + None, + db_conn=self.mock_db) @patch('records_mover.db.postgres.loader.quote_value') @patch('records_mover.db.postgres.loader.copy_from') @@ -57,9 +58,10 @@ def test_load_from_fileobj(self, mock_Table.assert_called_with(mock_table, self.mock_meta, schema=mock_schema, - autoload_with=self.mock_db) - mock_conn = self.mock_db.engine.begin.return_value.__enter__.return_value - mock_quote_value.assert_called_with(mock_conn, 'ISO, DATE_ORDER_STYLE') + autoload_with=self.mock_db.engine) + mock_conn = self.mock_db + mock_quote_value.assert_called_with(None, 'ISO, DATE_ORDER_STYLE', + db_engine=mock_conn.engine) str_arg = str(mock_conn.execute.call_args.args[0]) self.assertEqual(str_arg, 'SET LOCAL DateStyle = ABC') mock_copy_from.assert_called_with(mock_fileobj, @@ -110,9 +112,9 @@ def test_load_from_fileobj_default_date_order_style(self, mock_Table.assert_called_with(mock_table, self.mock_meta, schema=mock_schema, - autoload_with=self.mock_db) - mock_conn = self.mock_db.engine.begin.return_value.__enter__.return_value - mock_quote_value.assert_called_with(mock_conn, 'ISO, MDY') + autoload_with=self.mock_db.engine) + mock_conn = self.mock_db + mock_quote_value.assert_called_with(None, 'ISO, MDY', db_engine=mock_conn.engine) str_arg = str(mock_conn.execute.call_args.args[0]) self.assertEqual(str_arg, 'SET LOCAL DateStyle = ABC') mock_copy_from.assert_called_with(mock_fileobj, @@ -172,9 +174,10 @@ def test_load(self, mock_Table.assert_called_with(mock_table, self.mock_meta, schema=mock_schema, - autoload_with=self.mock_db) - mock_conn = self.mock_db.engine.begin.return_value.__enter__.return_value - 
mock_quote_value.assert_called_with(mock_conn, 'ISO, DATE_ORDER_STYLE') + autoload_with=self.mock_db.engine) + mock_conn = self.mock_db + mock_quote_value.assert_called_with(None, 'ISO, DATE_ORDER_STYLE', + db_engine=mock_conn.engine) str_arg = str(mock_conn.execute.call_args.args[0]) self.assertEqual(str_arg, 'SET LOCAL DateStyle = ABC') mock_copy_from.assert_called_with(mock_loc.open.return_value.__enter__.return_value, diff --git a/tests/unit/db/postgres/test_unloader.py b/tests/unit/db/postgres/test_unloader.py index 4adefb1e6..279d950f3 100644 --- a/tests/unit/db/postgres/test_unloader.py +++ b/tests/unit/db/postgres/test_unloader.py @@ -8,7 +8,7 @@ class TestPostgresUnloader(unittest.TestCase): def setUp(self): self.mock_url_resolver = Mock(name='url_resolver') self.mock_db = MagicMock(name='db') - self.unloader = PostgresUnloader(self.mock_db) + self.unloader = PostgresUnloader(None, db_conn=self.mock_db) @patch('records_mover.db.postgres.unloader.quote_value') @patch('records_mover.db.postgres.unloader.copy_to') @@ -55,9 +55,10 @@ def test_unload(self, mock_Table.assert_called_with(mock_table, ANY, schema=mock_schema, - autoload_with=self.mock_db) - mock_conn = self.mock_db.engine.begin.return_value.__enter__.return_value - mock_quote_value.assert_called_with(mock_conn, 'DATE_OUTPUT_STYLE, DATE_ORDER_STYLE') + autoload_with=self.mock_db.engine) + mock_conn = self.mock_db + mock_quote_value.assert_called_with(None, 'DATE_OUTPUT_STYLE, DATE_ORDER_STYLE', + db_engine=mock_conn.engine) str_arg = str(mock_conn.execute.call_args.args[0]) self.assertEqual(str_arg, 'SET LOCAL DateStyle = ABC') mock_fileobj = mock_directory.loc.file_in_this_directory.return_value.open.\ @@ -112,9 +113,10 @@ def test_unload_default_date_order_style(self, mock_Table.assert_called_with(mock_table, ANY, schema=mock_schema, - autoload_with=self.mock_db) - mock_conn = self.mock_db.engine.begin.return_value.__enter__.return_value - mock_quote_value.assert_called_with(mock_conn, 'DATE_OUTPUT_STYLE, MDY') + autoload_with=self.mock_db.engine) + mock_conn = self.mock_db + mock_quote_value.assert_called_with(None, 'DATE_OUTPUT_STYLE, MDY', + db_engine=mock_conn.engine) str_arg = str(mock_conn.execute.call_args.args[0]) self.assertEqual(str_arg, 'SET LOCAL DateStyle = ABC') mock_fileobj = mock_directory.loc.file_in_this_directory.return_value.open.\ diff --git a/tests/unit/db/redshift/base_test_redshift_db_driver.py b/tests/unit/db/redshift/base_test_redshift_db_driver.py index d3d7ca786..df67fc8a6 100644 --- a/tests/unit/db/redshift/base_test_redshift_db_driver.py +++ b/tests/unit/db/redshift/base_test_redshift_db_driver.py @@ -22,8 +22,9 @@ def setUp(self): quote.return_value.\ __add__.return_value.\ __add__.return_value = 'myschema.mytable' - self.redshift_db_driver = RedshiftDBDriver(db=self.mock_db_engine, - s3_temp_base_loc=self.mock_s3_temp_base_loc) + self.redshift_db_driver = RedshiftDBDriver(db=None, + s3_temp_base_loc=self.mock_s3_temp_base_loc, + db_engine=self.mock_db_engine) self.mock_db_engine.dialect.ddl_compiler.return_value = ( '\nCREATE TABLE myschema.mytable (\n)\n\n') mock_records_unload_plan = create_autospec(RecordsUnloadPlan) diff --git a/tests/unit/db/redshift/test_loader.py b/tests/unit/db/redshift/test_loader.py index 5858eb08f..19393ea76 100644 --- a/tests/unit/db/redshift/test_loader.py +++ b/tests/unit/db/redshift/test_loader.py @@ -12,9 +12,10 @@ def setUp(self): self.s3_temp_base_loc = MagicMock(name='s3_temp_base_loc') self.redshift_loader =\ - RedshiftLoader(db=self.mock_db, + 
RedshiftLoader(db=None, meta=self.mock_meta, - s3_temp_base_loc=self.s3_temp_base_loc) + s3_temp_base_loc=self.s3_temp_base_loc, + db_conn=self.mock_db) @patch('records_mover.db.redshift.loader.redshift_copy_options') @patch('records_mover.db.redshift.loader.ProcessingInstructions') diff --git a/tests/unit/db/redshift/test_loader_temporary_bucket.py b/tests/unit/db/redshift/test_loader_temporary_bucket.py index 966ac6607..201cd4825 100644 --- a/tests/unit/db/redshift/test_loader_temporary_bucket.py +++ b/tests/unit/db/redshift/test_loader_temporary_bucket.py @@ -10,9 +10,10 @@ def test_temporary_s3_directory_loc_no_bucket(self): mock_meta = Mock(name='meta') redshift_loader =\ - RedshiftLoader(db=mock_db, + RedshiftLoader(db=None, meta=mock_meta, - s3_temp_base_loc=None) + s3_temp_base_loc=None, + db_conn=mock_db) with self.assertRaises(NoTemporaryBucketConfiguration): with redshift_loader.temporary_s3_directory_loc(): diff --git a/tests/unit/db/redshift/test_redshift_db_driver.py b/tests/unit/db/redshift/test_redshift_db_driver.py index 82038be94..4f94e9221 100644 --- a/tests/unit/db/redshift/test_redshift_db_driver.py +++ b/tests/unit/db/redshift/test_redshift_db_driver.py @@ -26,8 +26,11 @@ def test_set_grant_permissions_for_group(self, mock_quote_schema_and_table, mock_quote_group_name.return_value = '"a_group"' groups = {'all': ['a_group']} mock_conn = self.mock_db_engine.engine.connect.return_value.__enter__.return_value - self.redshift_db_driver.set_grant_permissions_for_groups(mock_schema, mock_table, - groups, mock_conn) + self.redshift_db_driver.set_grant_permissions_for_groups(mock_schema, + mock_table, + groups, + None, + db_conn=mock_conn) mock_conn.execute.assert_called_with( f'GRANT all ON TABLE {mock_schema}.{mock_table} TO GROUP "a_group"') diff --git a/tests/unit/db/redshift/test_redshift_db_driver_format_negotiation.py b/tests/unit/db/redshift/test_redshift_db_driver_format_negotiation.py index 98833fa80..821d228fe 100644 --- a/tests/unit/db/redshift/test_redshift_db_driver_format_negotiation.py +++ b/tests/unit/db/redshift/test_redshift_db_driver_format_negotiation.py @@ -23,8 +23,9 @@ def setUp(self, mock_RedshiftUnloader, mock_RedshiftLoader): quote.return_value.\ __add__.return_value.\ __add__.return_value = 'myschema.mytable' - self.redshift_db_driver = RedshiftDBDriver(db=mock_db_engine, - s3_temp_base_loc=mock_s3_temp_base_loc) + self.redshift_db_driver = RedshiftDBDriver(db=None, + s3_temp_base_loc=mock_s3_temp_base_loc, + db_engine=mock_db_engine,) def test_can_load_this_format(self): mock_source_records_format = Mock(name='source_records_format', spec=DelimitedRecordsFormat) diff --git a/tests/unit/db/redshift/test_redshift_db_driver_unload.py b/tests/unit/db/redshift/test_redshift_db_driver_unload.py index deee9557b..a68adf374 100644 --- a/tests/unit/db/redshift/test_redshift_db_driver_unload.py +++ b/tests/unit/db/redshift/test_redshift_db_driver_unload.py @@ -24,7 +24,6 @@ def test_unload_to_non_s3(self, hints=bluelabs_format_hints) self.mock_directory.scheme = 'mumble' self.mock_db_engine.connect.return_value \ - .__enter__.return_value \ .execute.return_value \ .scalar.return_value = 456 rows = self.redshift_db_driver.unloader().\ @@ -65,7 +64,6 @@ def test_unload(self, hints=bluelabs_format_hints) self.mock_directory.scheme = 's3' self.mock_db_engine.connect.return_value \ - .__enter__.return_value \ .execute.return_value \ .scalar.return_value = 456 rows = self.redshift_db_driver.unloader().\ diff --git a/tests/unit/db/redshift/test_sql.py 
b/tests/unit/db/redshift/test_sql.py index 5ba2326b6..825ca3a46 100644 --- a/tests/unit/db/redshift/test_sql.py +++ b/tests/unit/db/redshift/test_sql.py @@ -17,7 +17,7 @@ def test_schema_sql_from_admin_views_not_installed(self, mock_connection.begin.return_value \ .__enter__.return_value = None mock_connection.execute.side_effect = sqlalchemy.exc.ProgrammingError('statement', {}, {}) - out = schema_sql_from_admin_views(mock_schema, mock_table, mock_db) + out = schema_sql_from_admin_views(mock_schema, mock_table, None, db_conn=mock_connection) self.assertIsNone(out) mock_logger.debug.assert_called_with('Error while generating SQL', exc_info=True) mock_logger.warning.\ diff --git a/tests/unit/db/redshift/test_unloader.py b/tests/unit/db/redshift/test_unloader.py index 8b81f6b97..9437e6746 100644 --- a/tests/unit/db/redshift/test_unloader.py +++ b/tests/unit/db/redshift/test_unloader.py @@ -23,9 +23,10 @@ def test_can_unload_format_true(self, mock_target_records_format.hints = {} redshift_unloader =\ - RedshiftUnloader(db=mock_db, + RedshiftUnloader(db=None, table=mock_table, - s3_temp_base_loc=mock_s3_temp_base_loc) + s3_temp_base_loc=mock_s3_temp_base_loc, + db_conn=mock_db) out = redshift_unloader.can_unload_format(mock_target_records_format) mock_RecordsUnloadPlan.\ assert_called_with(records_format=mock_target_records_format) @@ -53,9 +54,10 @@ def test_can_unload_format_delimite_false(self, mock_redshift_unload_options.side_effect = NotImplementedError redshift_unloader =\ - RedshiftUnloader(db=mock_db, + RedshiftUnloader(db=None, table=mock_table, - s3_temp_base_loc=mock_s3_temp_base_loc) + s3_temp_base_loc=mock_s3_temp_base_loc, + db_conn=mock_db) out = redshift_unloader.can_unload_format(mock_target_records_format) mock_RecordsUnloadPlan.\ assert_called_with(records_format=mock_target_records_format) @@ -70,9 +72,10 @@ def test_can_unload_to_scheme_s3_true(self): mock_table = Mock(name='table') redshift_unloader =\ - RedshiftUnloader(db=mock_db, + RedshiftUnloader(db=None, table=mock_table, - s3_temp_base_loc=None) + s3_temp_base_loc=None, + db_conn=mock_db) self.assertTrue(redshift_unloader.can_unload_to_scheme('s3')) def test_can_unload_to_scheme_file_without_temp_bucket_true(self): @@ -80,9 +83,10 @@ def test_can_unload_to_scheme_file_without_temp_bucket_true(self): mock_table = Mock(name='table') redshift_unloader =\ - RedshiftUnloader(db=mock_db, + RedshiftUnloader(db=None, table=mock_table, - s3_temp_base_loc=None) + s3_temp_base_loc=None, + db_conn=mock_db) self.assertFalse(redshift_unloader.can_unload_to_scheme('file')) def test_can_unload_to_scheme_file_with_temp_bucket_true(self): @@ -91,9 +95,10 @@ def test_can_unload_to_scheme_file_with_temp_bucket_true(self): mock_s3_temp_base_loc = Mock(name='s3_temp_base_loc') redshift_unloader =\ - RedshiftUnloader(db=mock_db, + RedshiftUnloader(db=None, table=mock_table, - s3_temp_base_loc=mock_s3_temp_base_loc) + s3_temp_base_loc=mock_s3_temp_base_loc, + db_conn=mock_db) self.assertTrue(redshift_unloader.can_unload_to_scheme('file')) def test_known_supported_records_formats_for_unload(self): @@ -102,9 +107,10 @@ def test_known_supported_records_formats_for_unload(self): mock_s3_temp_base_loc = Mock(name='s3_temp_base_loc') redshift_unloader =\ - RedshiftUnloader(db=mock_db, + RedshiftUnloader(db=None, table=mock_table, - s3_temp_base_loc=mock_s3_temp_base_loc) + s3_temp_base_loc=mock_s3_temp_base_loc, + db_conn=mock_db) formats = redshift_unloader.known_supported_records_formats_for_unload() self.assertEqual([f.__class__ for f in 
formats], @@ -116,9 +122,10 @@ def test_temporary_unloadable_directory_loc(self): mock_s3_temp_base_loc = MagicMock(name='s3_temp_base_loc') redshift_unloader =\ - RedshiftUnloader(db=mock_db, + RedshiftUnloader(db=None, table=mock_table, - s3_temp_base_loc=mock_s3_temp_base_loc) + s3_temp_base_loc=mock_s3_temp_base_loc, + db_conn=mock_db) with redshift_unloader.temporary_unloadable_directory_loc() as loc: self.assertEqual(loc, mock_s3_temp_base_loc.temporary_directory.return_value.__enter__. @@ -130,9 +137,10 @@ def test_temporary_unloadable_directory_loc_unset(self): mock_s3_temp_base_loc = None redshift_unloader =\ - RedshiftUnloader(db=mock_db, + RedshiftUnloader(db=None, table=mock_table, - s3_temp_base_loc=mock_s3_temp_base_loc) + s3_temp_base_loc=mock_s3_temp_base_loc, + db_conn=mock_db) with self.assertRaises(NoTemporaryBucketConfiguration): with redshift_unloader.temporary_unloadable_directory_loc(): pass diff --git a/tests/unit/db/test_db_driver.py b/tests/unit/db/test_db_driver.py index 69075441b..c2060105b 100644 --- a/tests/unit/db/test_db_driver.py +++ b/tests/unit/db/test_db_driver.py @@ -11,7 +11,8 @@ def setUp(self): self.mock_url_resolver = Mock(name='url_resolver') self.mock_s3_temp_base_loc = MagicMock(name='s3_temp_base_loc') self.mock_s3_temp_base_loc.url = 's3://fakebucket/fakedir/fakesubdir/' - self.db_driver = GenericDBDriver(db=self.mock_db_engine, + self.db_driver = GenericDBDriver(db=None, + db_engine=self.mock_db_engine, s3_temp_base_loc=self.mock_s3_temp_base_loc, url_resolver=self.mock_url_resolver, text=fake_text) @@ -69,10 +70,9 @@ def test_has_table(self): mock_schema = Mock(name='schema') mock_table = Mock(name='table') out = self.db_driver.has_table(mock_schema, mock_table) - self.assertEqual(out, self.mock_db_engine.dialect.has_table.return_value) - self.mock_db_engine.dialect.has_table.assert_called_with(self.mock_db_engine, - schema=mock_schema, - table_name=mock_table) + self.assertEqual(out, self.mock_db_engine._sa_instance_state.has_table.return_value) + self.mock_db_engine._sa_instance_state.has_table.assert_called_with(mock_table, + schema=mock_schema) @patch('records_mover.db.driver.quote_group_name') @patch('records_mover.db.driver.quote_schema_and_table') @@ -90,10 +90,12 @@ def test_set_grant_permissions_for_groups(self, self.db_driver.set_grant_permissions_for_groups(mock_schema_name, mock_table, groups, - mock_db) - mock_quote_schema_and_table.assert_called_with(self.mock_db_engine.engine, + None, + db_conn=mock_db) + mock_quote_schema_and_table.assert_called_with(None, mock_schema_name, - mock_table) + mock_table, + db_engine=self.mock_db_engine) mock_db.execute.assert_has_calls([ call(f"GRANT write ON TABLE {mock_schema_and_table} TO {mock_group_name}"), call(f"GRANT write ON TABLE {mock_schema_and_table} TO {mock_group_name}"), @@ -115,10 +117,12 @@ def test_set_grant_permissions_for_users(self, self.db_driver.set_grant_permissions_for_users(mock_schema_name, mock_table, users, - mock_db) - mock_quote_schema_and_table.assert_called_with(self.mock_db_engine.engine, + None, + db_conn=mock_db) + mock_quote_schema_and_table.assert_called_with(None, mock_schema_name, - mock_table) + mock_table, + db_engine=self.mock_db_engine) mock_db.execute.assert_has_calls([ call(f"GRANT write ON TABLE {mock_schema_and_table} TO {mock_user_name}"), call(f"GRANT write ON TABLE {mock_schema_and_table} TO {mock_user_name}"), @@ -139,7 +143,8 @@ def test_set_grant_permissions_for_users_bobby_tables(self, 
self.db_driver.set_grant_permissions_for_users(mock_schema_name, mock_table, users, - mock_db) + None, + db_conn=mock_db) @patch('records_mover.db.driver.quote_user_name') @patch('records_mover.db.driver.quote_schema_and_table') @@ -156,7 +161,8 @@ def test_set_grant_permissions_for_groups_bobby_tables(self, self.db_driver.set_grant_permissions_for_groups(mock_schema_name, mock_table, groups, - mock_db) + None, + db_conn=mock_db) def test_tweak_records_schema_for_load_no_tweak(self): mock_records_schema = Mock(name='records_schema') diff --git a/tests/unit/db/test_factory.py b/tests/unit/db/test_factory.py index 368ee6967..3a40751bc 100644 --- a/tests/unit/db/test_factory.py +++ b/tests/unit/db/test_factory.py @@ -10,7 +10,7 @@ def test_db_driver_vertica(self, mock_db = Mock(name='db') mock_engine = mock_db.engine mock_engine.name = 'vertica' - out = db_driver(mock_db) + out = db_driver(None, db_conn=mock_db) self.assertEqual(out, mock_VerticaDBDriver.return_value) @patch('records_mover.db.redshift.redshift_db_driver.RedshiftDBDriver') @@ -19,7 +19,7 @@ def test_db_driver_redshift(self, mock_db = Mock(name='db') mock_engine = mock_db.engine mock_engine.name = 'redshift' - out = db_driver(mock_db) + out = db_driver(None, db_conn=mock_db) self.assertEqual(out, mock_RedshiftDBDriver.return_value) @patch('records_mover.db.bigquery.bigquery_db_driver.BigQueryDBDriver') @@ -28,7 +28,7 @@ def test_db_driver_bigquery(self, mock_db = Mock(name='db') mock_engine = mock_db.engine mock_engine.name = 'bigquery' - out = db_driver(mock_db) + out = db_driver(None, db_conn=mock_db) self.assertEqual(out, mock_BigQueryDBDriver.return_value) @patch('records_mover.db.factory.GenericDBDriver') @@ -37,5 +37,5 @@ def test_db_driver_other(self, mock_db = Mock(name='db') mock_engine = mock_db.engine mock_engine.name = 'somaskdfaksjf' - out = db_driver(mock_db) + out = db_driver(None, db_conn=mock_db) self.assertEqual(out, mock_GenericDBDriver.return_value) diff --git a/tests/unit/db/test_quoting.py b/tests/unit/db/test_quoting.py index 7697c4005..4eee5ab34 100644 --- a/tests/unit/db/test_quoting.py +++ b/tests/unit/db/test_quoting.py @@ -8,7 +8,7 @@ def test_quote_table_only(self): mock_engine = Mock() quotable_value = Mock() mock_preparer = mock_engine.dialect.preparer.return_value - quoted = quoting.quote_table_only(mock_engine, quotable_value) + quoted = quoting.quote_table_only(None, quotable_value, db_engine=mock_engine) self.assertEqual(mock_preparer.quote.return_value, quoted) mock_preparer.quote.assert_called_with(quotable_value) @@ -16,7 +16,7 @@ def test_quote_column_name(self): mock_engine = Mock() quotable_value = Mock() mock_preparer = mock_engine.dialect.preparer.return_value - quoted = quoting.quote_column_name(mock_engine, quotable_value) + quoted = quoting.quote_column_name(None, quotable_value, db_engine=mock_engine) self.assertEqual(mock_preparer.quote.return_value, quoted) mock_preparer.quote.assert_called_with(quotable_value) @@ -26,9 +26,10 @@ def test_quote_schema_and_table(self): quotable_schema = Mock() mock_preparer = mock_engine.dialect.preparer.return_value mock_preparer.quote.return_value = '"foo"' - quoted = quoting.quote_schema_and_table(mock_engine, + quoted = quoting.quote_schema_and_table(None, quotable_schema, - quotable_table) + quotable_table, + db_engine=mock_engine) self.assertEqual(mock_preparer.quote.return_value + "." 
+ mock_preparer.quote.return_value, quoted) @@ -39,7 +40,7 @@ def test_quote_value(self): mock_engine = Mock() quotable_value = Mock() mock_preparer = mock_engine.dialect.preparer.return_value - quoted = quoting.quote_value(mock_engine, quotable_value) + quoted = quoting.quote_value(None, quotable_value, db_engine=mock_engine) self.assertEqual(mock_preparer.quote.return_value, quoted) mock_preparer.quote.assert_called_with(quotable_value) mock_engine.dialect.preparer.assert_called_with(mock_engine.dialect, diff --git a/tests/unit/db/vertica/base_test_vertica_db_driver.py b/tests/unit/db/vertica/base_test_vertica_db_driver.py index 9661fb18f..11baa015c 100644 --- a/tests/unit/db/vertica/base_test_vertica_db_driver.py +++ b/tests/unit/db/vertica/base_test_vertica_db_driver.py @@ -20,9 +20,10 @@ def setUp(self): self.mock_s3_temp_base_loc.url = 's3://fakebucket/fakedir/fakesubdir/' with patch('records_mover.db.vertica.vertica_db_driver.VerticaLoader') \ as mock_VerticaLoader: - self.vertica_db_driver = VerticaDBDriver(db=self.mock_db_engine, + self.vertica_db_driver = VerticaDBDriver(db=None, s3_temp_base_loc=self.mock_s3_temp_base_loc, - url_resolver=self.mock_url_resolver) + url_resolver=self.mock_url_resolver, + db_conn=self.mock_db_engine) self.mock_VerticaLoader = mock_VerticaLoader self.mock_vertica_loader = mock_VerticaLoader.return_value diff --git a/tests/unit/db/vertica/test_import_sql.py b/tests/unit/db/vertica/test_import_sql.py index 9b6dcf25c..3a96ef63e 100644 --- a/tests/unit/db/vertica/test_import_sql.py +++ b/tests/unit/db/vertica/test_import_sql.py @@ -16,12 +16,12 @@ class TestImportSQL(unittest.TestCase): def test_vertica_import_sql(self, mock_quote_value, mock_quote_schema_and_table): - def null_schema_table_quoter(db_engine, schema, table): + def null_schema_table_quoter(db, schema, table, db_engine=None): return f"{schema}.{table}" mock_quote_schema_and_table.side_effect = null_schema_table_quoter - def simple_value_quoter(db_engine, s): + def simple_value_quoter(db, s, db_engine=None): return f"'{s}'" mock_quote_value.side_effect = simple_value_quoter diff --git a/tests/unit/db/vertica/test_loader.py b/tests/unit/db/vertica/test_loader.py index 00aafef2f..eb81c8adf 100644 --- a/tests/unit/db/vertica/test_loader.py +++ b/tests/unit/db/vertica/test_loader.py @@ -11,7 +11,8 @@ class TestVerticaLoader(unittest.TestCase): def setUp(self): mock_url_resolver = Mock(name='url_resolver') mock_db = Mock(name='db') - self.vertica_loader = VerticaLoader(url_resolver=mock_url_resolver, db=mock_db) + self.vertica_loader = VerticaLoader(url_resolver=mock_url_resolver, db=None, + db_conn=mock_db) @patch('records_mover.db.vertica.loader.ProcessingInstructions') @patch('records_mover.db.vertica.loader.RecordsLoadPlan') diff --git a/tests/unit/db/vertica/test_unloader.py b/tests/unit/db/vertica/test_unloader.py index 4a83eada9..06ac3ec97 100644 --- a/tests/unit/db/vertica/test_unloader.py +++ b/tests/unit/db/vertica/test_unloader.py @@ -12,7 +12,8 @@ def test_can_unload_format_true(self, mock_vertica_export_options): mock_db = Mock(name='db') mock_source_records_format = Mock(name='source_records_format', spec=DelimitedRecordsFormat) mock_s3_temp_base_loc = Mock(name='s3_temp_base_loc') - vertica_unloader = VerticaUnloader(db=mock_db, s3_temp_base_loc=mock_s3_temp_base_loc) + vertica_unloader = VerticaUnloader(db=None, s3_temp_base_loc=mock_s3_temp_base_loc, + db_conn=mock_db) mock_source_records_format.hints = {} out = vertica_unloader.can_unload_format(mock_source_records_format) 
mock_vertica_export_options.assert_called_with(set(), ANY) @@ -23,7 +24,8 @@ def test_can_unload_format_false(self, mock_vertica_export_options): mock_db = Mock(name='db') mock_source_records_format = Mock(name='source_records_format', spec=DelimitedRecordsFormat) mock_s3_temp_base_loc = Mock(name='s3_temp_base_loc') - vertica_unloader = VerticaUnloader(db=mock_db, s3_temp_base_loc=mock_s3_temp_base_loc) + vertica_unloader = VerticaUnloader(db=None, s3_temp_base_loc=mock_s3_temp_base_loc, + db_conn=mock_db) mock_source_records_format.hints = {} mock_vertica_export_options.side_effect = NotImplementedError out = vertica_unloader.can_unload_format(mock_source_records_format) @@ -34,7 +36,8 @@ def test_known_supported_records_formats_for_unload(self): mock_db = Mock(name='db') mock_source_records_format = Mock(name='source_records_format', spec=DelimitedRecordsFormat) mock_s3_temp_base_loc = Mock(name='s3_temp_base_loc') - vertica_unloader = VerticaUnloader(db=mock_db, s3_temp_base_loc=mock_s3_temp_base_loc) + vertica_unloader = VerticaUnloader(db=None, s3_temp_base_loc=mock_s3_temp_base_loc, + db_conn=mock_db) mock_source_records_format.hints = {} out = vertica_unloader.known_supported_records_formats_for_unload() self.assertEqual(out, [DelimitedRecordsFormat(variant='vertica')]) diff --git a/tests/unit/db/vertica/test_unloader_no_aws_creds.py b/tests/unit/db/vertica/test_unloader_no_aws_creds.py index 2c2fdc684..856732c0b 100644 --- a/tests/unit/db/vertica/test_unloader_no_aws_creds.py +++ b/tests/unit/db/vertica/test_unloader_no_aws_creds.py @@ -11,7 +11,8 @@ def test_unload_to_s3_directory_with_token(self): mock_s3_temp_base_loc = Mock(name='s3_temp_base_loc') mock_out = mock_db.execute.return_value mock_out.fetchall.return_value = [] - unloader = VerticaUnloader(db=mock_db, s3_temp_base_loc=mock_s3_temp_base_loc) + unloader = VerticaUnloader(db=None, s3_temp_base_loc=mock_s3_temp_base_loc, + db_conn=mock_db) mock_table = Mock(name='table') mock_unload_plan = Mock(name='unload_plan') mock_schema = Mock(name='schema') @@ -38,7 +39,8 @@ def test_unload_with_no_aws_creds(self, .__enter__.return_value = mock_connection mock_out = mock_connection.execute.return_value mock_out.fetchall.return_value = ['awslib'] - unloader = VerticaUnloader(db=mock_db, s3_temp_base_loc=mock_s3_temp_base_loc) + unloader = VerticaUnloader(db=None, s3_temp_base_loc=mock_s3_temp_base_loc, + db_conn=mock_db) mock_table = Mock(name='table') mock_unload_plan = Mock(name='unload_plan') mock_unload_plan.records_format = Mock(spec=DelimitedRecordsFormat) @@ -68,5 +70,6 @@ def test_s3_export_available_false_no_awslib(self): .__enter__.return_value = mock_connection mock_out = mock_connection.execute.return_value mock_out.fetchall.return_value = [] - unloader = VerticaUnloader(db=mock_db, s3_temp_base_loc=mock_s3_temp_base_loc) + unloader = VerticaUnloader(db=None, s3_temp_base_loc=mock_s3_temp_base_loc, + db_conn=mock_db) self.assertEqual(False, unloader.s3_export_available()) diff --git a/tests/unit/db/vertica/test_unloader_no_s3.py b/tests/unit/db/vertica/test_unloader_no_s3.py index 7c394f3cd..3e69a91dc 100644 --- a/tests/unit/db/vertica/test_unloader_no_s3.py +++ b/tests/unit/db/vertica/test_unloader_no_s3.py @@ -1,7 +1,7 @@ from records_mover.db.vertica.unloader import VerticaUnloader from records_mover.db.errors import NoTemporaryBucketConfiguration import unittest -from mock import Mock, MagicMock +from mock import Mock class TestVerticaUnloaderNoS3(unittest.TestCase): @@ -10,42 +10,38 @@ class 
TestVerticaUnloaderNoS3(unittest.TestCase): def test_temporary_unloadable_directory_load_with_no_s3_temp_bucket_configured(self): mock_db = Mock(name='db') mock_s3_temp_base_loc = None - vertica_unloader = VerticaUnloader(db=mock_db, s3_temp_base_loc=mock_s3_temp_base_loc) + vertica_unloader = VerticaUnloader(db=None, s3_temp_base_loc=mock_s3_temp_base_loc, + db_conn=mock_db) with self.assertRaises(NoTemporaryBucketConfiguration): with vertica_unloader.temporary_unloadable_directory_loc(): pass def test_can_unload_to_scheme_s3_but_no_s3_export_false(self): - mock_db = MagicMock(name='db') + mock_db = Mock(name='db') mock_resultset = Mock(name='resultset') - mock_connection = MagicMock(name='connection') - mock_db.connect.return_value \ - .__enter__.return_value = mock_connection - mock_connection.execute.return_value = mock_resultset + mock_db.execute.return_value = mock_resultset mock_resultset.fetchall.return_value = [] mock_s3_temp_base_loc = None - vertica_unloader = VerticaUnloader(db=mock_db, s3_temp_base_loc=mock_s3_temp_base_loc) + vertica_unloader = VerticaUnloader(db=None, s3_temp_base_loc=mock_s3_temp_base_loc, + db_conn=mock_db) self.assertFalse(vertica_unloader.can_unload_to_scheme('s3')) - str_arg = str(mock_connection.execute.call_args.args[0]) - self.assertEqual(str_arg, "SELECT lib_name from user_libraries where lib_name = 'awslib'") def test_can_unload_to_scheme_s3_but_with_s3_export_true(self): - mock_db = MagicMock(name='db') + mock_db = Mock(name='db') mock_resultset = Mock(name='resultset') - mock_connection = MagicMock(name='connection') - mock_db.connect.return_value \ - .__enter__.return_value = mock_connection - mock_connection.execute.return_value = mock_resultset + mock_db.execute.return_value = mock_resultset mock_resultset.fetchall.return_value = ['awslib'] mock_s3_temp_base_loc = None - vertica_unloader = VerticaUnloader(db=mock_db, s3_temp_base_loc=mock_s3_temp_base_loc) + vertica_unloader = VerticaUnloader(db=None, s3_temp_base_loc=mock_s3_temp_base_loc, + db_conn=mock_db) self.assertTrue(vertica_unloader.can_unload_to_scheme('s3')) - str_arg = str(mock_connection.execute.call_args.args[0]) + + str_arg = str(mock_db.execute.call_args.args[0]) self.assertEqual(str_arg, "SELECT lib_name from user_libraries where lib_name = 'awslib'") def test_s3_temp_bucket_available_false(self): mock_db = Mock(name='db') - vertica_unloader = VerticaUnloader(db=mock_db, s3_temp_base_loc=None) + vertica_unloader = VerticaUnloader(db=None, s3_temp_base_loc=None, db_conn=mock_db) self.assertFalse(vertica_unloader.s3_temp_bucket_available()) diff --git a/tests/unit/db/vertica/test_vertica_db_driver.py b/tests/unit/db/vertica/test_vertica_db_driver.py index ef6415a33..e9b0ba9c5 100644 --- a/tests/unit/db/vertica/test_vertica_db_driver.py +++ b/tests/unit/db/vertica/test_vertica_db_driver.py @@ -1,5 +1,5 @@ from .base_test_vertica_db_driver import BaseTestVerticaDBDriver -from mock import Mock, MagicMock +from mock import Mock from ...records.format_hints import (vertica_format_hints) import sqlalchemy @@ -8,10 +8,7 @@ class TestVerticaDBDriver(BaseTestVerticaDBDriver): def test_unload(self): mock_result = Mock(name='result') mock_result.rows = 579 - mock_connection = MagicMock(name='connection') - self.mock_db_engine.connect.return_value \ - .__enter__.return_value = mock_connection - mock_connection.execute.return_value.fetchall.return_value = [mock_result] + self.mock_db_engine.execute.return_value.fetchall.return_value = [mock_result] 
self.mock_records_unload_plan.processing_instructions.fail_if_dont_understand = True self.mock_records_unload_plan.processing_instructions.fail_if_cant_handle_hint = True @@ -28,10 +25,7 @@ def test_unload(self): def test_unload_to_non_s3(self): mock_result = Mock(name='result') mock_result.rows = 579 - mock_connection = MagicMock(name='connection') - self.mock_db_engine.connect.return_value \ - .__enter__.return_value = mock_connection - mock_connection.execute.return_value.fetchall.return_value = [mock_result] + self.mock_db_engine.execute.return_value.fetchall.return_value = [mock_result] self.mock_records_unload_plan.processing_instructions.fail_if_dont_understand = True self.mock_records_unload_plan.processing_instructions.fail_if_cant_handle_hint = True @@ -52,7 +46,7 @@ def test_schema_sql(self): mock_result = Mock(name='result') self.mock_db_engine.execute.return_value.fetchall.return_value = [mock_result] sql = self.vertica_db_driver.schema_sql('myschema', 'mytable') - self.assertEqual(sql, '\nCREATE TABLE myschema.mytable (\n)\n\n') + self.assertEqual(sql, mock_result.EXPORT_OBJECTS) def test_schema_sql_but_not_from_export_objects(self): self.mock_db_engine.execute.return_value.fetchall.return_value = [] diff --git a/tests/unit/db/vertica/test_vertica_db_driver_load.py b/tests/unit/db/vertica/test_vertica_db_driver_load.py index 1ee13a73d..af7a10c32 100644 --- a/tests/unit/db/vertica/test_vertica_db_driver_load.py +++ b/tests/unit/db/vertica/test_vertica_db_driver_load.py @@ -18,9 +18,10 @@ def setUp(self): self.mock_s3_temp_base_loc = MagicMock(name='s3_temp_base_loc') self.mock_url_resolver = Mock(name='url_resolver') self.mock_s3_temp_base_loc.url = 's3://fakebucket/fakedir/fakesubdir/' - self.vertica_db_driver = VerticaDBDriver(db=self.mock_db_engine, + self.vertica_db_driver = VerticaDBDriver(db=None, s3_temp_base_loc=self.mock_s3_temp_base_loc, - url_resolver=self.mock_url_resolver) + url_resolver=self.mock_url_resolver, + db_conn=self.mock_db_engine) mock_records_unload_plan = create_autospec(RecordsUnloadPlan) mock_records_unload_plan.records_format = create_autospec(DelimitedRecordsFormat) diff --git a/tests/unit/records/schema/field/sqlalchemy/base_test_field_from_sqlalchemy_column.py b/tests/unit/records/schema/field/sqlalchemy/base_test_field_from_sqlalchemy_column.py index 273a71983..3b5bc9a45 100644 --- a/tests/unit/records/schema/field/sqlalchemy/base_test_field_from_sqlalchemy_column.py +++ b/tests/unit/records/schema/field/sqlalchemy/base_test_field_from_sqlalchemy_column.py @@ -32,7 +32,7 @@ def verify(self, out = field_from_sqlalchemy_column(self.mock_column, self.mock_driver, self.mock_rep_type) mock_RecordsSchemaFieldRepresentation.from_sqlalchemy_column.\ assert_called_with(self.mock_column, - self.mock_driver.db.dialect, + self.mock_driver.db_engine.dialect, self.mock_rep_type) expected = mock_RecordsSchemaField.return_value self.mock_field_type = field_type diff --git a/tests/unit/records/sources/test_table.py b/tests/unit/records/sources/test_table.py index ba548b4a6..0c0ac34df 100644 --- a/tests/unit/records/sources/test_table.py +++ b/tests/unit/records/sources/test_table.py @@ -30,28 +30,28 @@ def test_to_dataframes_source(self, mock_DataframesRecordsSource): mock_processing_instructions = Mock(name='processing_instructions') mock_records_schema = mock_RecordsSchema.from_db_table.return_value - mock_db = self.mock_driver.db + mock_db_engine = self.mock_driver.db_engine + mock_db_conn = self.mock_driver.db_conn mock_connection = 
MagicMock(name='connection') - mock_db.connect.return_value \ - .__enter__.return_value = mock_connection + mock_db_engine.connect.return_value.__enter__.return_value = mock_connection mock_column = Mock(name='column') mock_columns = [mock_column] - mock_db.dialect.get_columns.return_value = mock_columns + mock_db_engine.dialect.get_columns.return_value = mock_columns mock_quoted_table = mock_quote_schema_and_table.return_value mock_chunks = mock_read_sql.return_value with self.table_records_source.to_dataframes_source(mock_processing_instructions) as\ df_source: self.assertEqual(df_source, mock_DataframesRecordsSource.return_value) - mock_db.dialect.get_columns.assert_called_with(mock_db, - self.mock_table_name, - schema=self.mock_schema_name) + mock_db_engine.dialect.get_columns.assert_called_with(mock_db_conn, + self.mock_table_name, + schema=self.mock_schema_name) mock_RecordsSchema.from_db_table.assert_called_with(self.mock_schema_name, self.mock_table_name, driver=self.mock_driver) str_arg = str(mock_read_sql.call_args.args[0]) self.assertEqual(str_arg, f"SELECT * FROM {mock_quoted_table}") kwargs = mock_read_sql.call_args.kwargs - self.assertEqual(kwargs['con'], mock_connection) + self.assertEqual(kwargs['con'], mock_db_conn) self.assertEqual(kwargs['chunksize'], 2000000) mock_DataframesRecordsSource.\ assert_called_with(dfs=ANY, diff --git a/tests/unit/records/targets/table/test_move_from_dataframes_source.py b/tests/unit/records/targets/table/test_move_from_dataframes_source.py index 31746874d..d2a3ec90e 100644 --- a/tests/unit/records/targets/table/test_move_from_dataframes_source.py +++ b/tests/unit/records/targets/table/test_move_from_dataframes_source.py @@ -62,7 +62,7 @@ def test_move_via_insert(self, mock_prep_and_load): mock_records_schema = self.mock_dfs_source.initial_records_schema.return_value mock_schema_sql = mock_records_schema.to_schema_sql.return_value out = self.algo.move() - self.mock_tbl.db_driver.assert_called_with(self.mock_tbl.db_engine) + self.mock_tbl.db_driver.assert_called_with(db=None, db_engine=self.mock_tbl.db_engine) self.mock_dfs_source.initial_records_schema.\ assert_called_with(self.mock_processing_instructions) mock_records_schema.to_schema_sql.assert_called_with(mock_driver, diff --git a/tests/unit/records/targets/table/test_move_from_fileobjs_source.py b/tests/unit/records/targets/table/test_move_from_fileobjs_source.py index ab4289f0f..f9694cdd9 100644 --- a/tests/unit/records/targets/table/test_move_from_fileobjs_source.py +++ b/tests/unit/records/targets/table/test_move_from_fileobjs_source.py @@ -41,7 +41,7 @@ def test_move(self): mock_loader_from_fileobj = mock_driver.loader_from_fileobj.return_value mock_import_count = mock_loader_from_fileobj.load_from_fileobj.return_value out = self.algo.move() - self.mock_tbl.db_driver.assert_called_with(mock_db) + self.mock_tbl.db_driver.assert_called_with(db=None, db_conn=mock_db) self.mock_RecordsLoadPlan.\ assert_called_with(records_format=self.mock_fileobjs_source.records_format, processing_instructions=self.mock_processing_instructions) @@ -75,7 +75,7 @@ class MyException(Exception): mock_loader_from_fileobj.load_failure_exception.assert_called_with() - self.mock_tbl.db_driver.assert_called_with(mock_db) + self.mock_tbl.db_driver.assert_called_with(db=None, db_conn=mock_db) mock_tweaked_records_schema.to_schema_sql.\ assert_called_with(mock_driver, self.mock_tbl.schema_name, @@ -112,7 +112,7 @@ class MyException(Exception): out = self.algo.move() 
mock_loader_from_fileobj.load_failure_exception.assert_called_with() - self.mock_tbl.db_driver.assert_called_with(mock_db) + self.mock_tbl.db_driver.assert_called_with(db=None, db_conn=mock_db) mock_tweaked_records_schema.to_schema_sql.\ assert_called_with(mock_driver, self.mock_tbl.schema_name, diff --git a/tests/unit/records/targets/table/test_move_from_records_directory.py b/tests/unit/records/targets/table/test_move_from_records_directory.py index 37d03ac4c..6ad7403c8 100644 --- a/tests/unit/records/targets/table/test_move_from_records_directory.py +++ b/tests/unit/records/targets/table/test_move_from_records_directory.py @@ -36,7 +36,7 @@ def test_move_happy_path(self, mock_plan = mock_RecordsLoadPlan.return_value out = self.algo.move() self.mock_prep.prep.assert_called_with(schema_sql=mock_schema_sql, driver=mock_driver) - self.mock_tbl.db_driver.assert_called_with(mock_db) + self.mock_tbl.db_driver.assert_called_with(db=None, db_conn=mock_db) mock_driver.tweak_records_schema_for_load.\ assert_called_with(mock_schema_obj, mock_plan.records_format) mock_tweaked_records_schema.to_schema_sql.assert_called_with(mock_driver, @@ -65,7 +65,7 @@ def test_move_legacy_schema_sql(self, out = self.algo.move() self.mock_prep.prep.assert_called_with(schema_sql=mock_schema_sql, driver=mock_driver) - self.mock_tbl.db_driver.assert_called_with(mock_db) + self.mock_tbl.db_driver.assert_called_with(db=None, db_conn=mock_db) self.mock_directory.load_schema_sql_from_sql_file.assert_called_with() mock_RecordsLoadPlan.\ assert_called_with(records_format=mock_records_format, @@ -101,7 +101,7 @@ def test_move_no_override(self, self.mock_directory.load_format.\ assert_called_with(self.mock_processing_instructions.fail_if_dont_understand) self.mock_prep.prep.assert_called_with(schema_sql=mock_schema_sql, driver=mock_driver) - self.mock_tbl.db_driver.assert_called_with(mock_db) + self.mock_tbl.db_driver.assert_called_with(db=None, db_conn=mock_db) self.mock_directory.load_schema_sql_from_sql_file.assert_called_with() mock_RecordsLoadPlan.\ diff --git a/tests/unit/records/targets/table/test_move_from_temp_loc_after_filling_it.py b/tests/unit/records/targets/table/test_move_from_temp_loc_after_filling_it.py index 1a8b7cb57..66cc02545 100644 --- a/tests/unit/records/targets/table/test_move_from_temp_loc_after_filling_it.py +++ b/tests/unit/records/targets/table/test_move_from_temp_loc_after_filling_it.py @@ -31,7 +31,7 @@ def test_init(self, mock_directory = mock_RecordsDirectory.return_value out = algo.move() mock_records_source.compatible_format.assert_called_with(mock_table_target) - mock_tbl.db_driver.assert_called_with(mock_tbl.db_engine) + mock_tbl.db_driver.assert_called_with(db=None, db_engine=mock_tbl.db_engine) mock_RecordsDirectory.assert_called_with(records_loc=mock_temp_loc) mock_records_source.move_to_records_directory.\ assert_called_with(records_directory=mock_directory, diff --git a/tests/unit/records/targets/table/test_target.py b/tests/unit/records/targets/table/test_target.py index 20067ccfe..df145a13c 100644 --- a/tests/unit/records/targets/table/test_target.py +++ b/tests/unit/records/targets/table/test_target.py @@ -37,21 +37,21 @@ def setUp(self): def test_can_move_from_fileobjs_source_yes(self): self.assertTrue(self.target.can_move_from_fileobjs_source()) - self.mock_db_driver.assert_called_with(self.mock_db_engine) + self.mock_db_driver.assert_called_with(None, db_engine=self.mock_db_engine, db_conn=None) def test_can_move_directly_from_scheme_no_loader(self): mock_driver = 
self.mock_db_driver.return_value
         mock_driver.loader.return_value = None
         self.assertFalse(self.target.can_move_directly_from_scheme('whatever'))
-        self.mock_db_driver.assert_called_with(self.mock_db_engine)
+        self.mock_db_driver.assert_called_with(None, db_engine=self.mock_db_engine, db_conn=None)

     def test_known_supported_records_formats_no_loader(self):
         mock_driver = self.mock_db_driver.return_value
         mock_driver.loader.return_value = None
         self.assertEqual([], self.target.known_supported_records_formats())
-        self.mock_db_driver.assert_called_with(self.mock_db_engine)
+        self.mock_db_driver.assert_called_with(None, db_engine=self.mock_db_engine, db_conn=None)

     def test_can_move_from_format_no_loader(self):
         mock_driver = self.mock_db_driver.return_value
@@ -59,7 +59,7 @@ def test_can_move_from_format_no_loader(self):
         mock_driver.loader.return_value = None
         self.assertFalse(self.target.can_move_from_format(mock_source_records_format))
-        self.mock_db_driver.assert_called_with(self.mock_db_engine)
+        self.mock_db_driver.assert_called_with(None, db_engine=self.mock_db_engine, db_conn=None)

     def test_can_move_from_format_with_loader_true(self):
         mock_driver = self.mock_db_driver.return_value
@@ -68,7 +68,7 @@
         mock_loader.has_temporary_loadable_directory_loc.return_value = True
         self.assertTrue(self.target.can_move_from_temp_loc_after_filling_it())
-        self.mock_db_driver.assert_called_with(self.mock_db_engine)
+        self.mock_db_driver.assert_called_with(None, db_engine=self.mock_db_engine, db_conn=None)
         mock_loader.has_temporary_loadable_directory_loc.assert_called_with()

     def test_temporary_loadable_directory_schemer(self):
diff --git a/tests/unit/records/targets/test_factory.py b/tests/unit/records/targets/test_factory.py
index 243582d8b..5bd4366ba 100644
--- a/tests/unit/records/targets/test_factory.py
+++ b/tests/unit/records/targets/test_factory.py
@@ -45,7 +45,8 @@ def test_table(self, mock_TableRecordsTarget):
             drop_and_recreate_on_load_error=False,
             existing_table_handling=existing_table_handling,
             add_group_perms_for=None,
-            add_user_perms_for=None)
+            add_user_perms_for=None,
+            db_conn=None)
         self.assertEqual(table, mock_TableRecordsTarget.return_value)

     @patch('records_mover.records.targets.google_sheets.GoogleSheetsRecordsTarget')
diff --git a/tests/unit/records/targets/test_spectrum.py b/tests/unit/records/targets/test_spectrum.py
index e1b22de7a..bddbdbc84 100644
--- a/tests/unit/records/targets/test_spectrum.py
+++ b/tests/unit/records/targets/test_spectrum.py
@@ -36,7 +36,7 @@ def setUp(self,
     def test_init(self):
         self.assertEqual(self.target.records_format, self.records_format)
-        self.assertEqual(self.target.db, self.mock_db)
+        self.assertEqual(self.target.db_engine, self.mock_db)

     @patch('records_mover.records.targets.spectrum.quote_schema_and_table')
     def test_pre_load_hook_preps_bucket_with_default_prep(self, mock_quote_schema_and_table):
@@ -44,9 +44,10 @@ def test_pre_load_hook_preps_bucket_with_default_prep(self, mock_quote_schema_an
         mock_cursor = self.target.driver.db_engine.connect.return_value.__enter__.return_value
         self.target.pre_load_hook()

-        mock_quote_schema_and_table.assert_called_with(self.target.db,
+        mock_quote_schema_and_table.assert_called_with(None,
                                                        self.target.schema_name,
-                                                       self.target.table_name)
+                                                       self.target.table_name,
+                                                       db_engine=self.target.db_engine)
         mock_cursor.execution_options.assert_called_with(isolation_level='AUTOCOMMIT')
         arg = mock_cursor.execute.call_args.args[0]
         arg_str = str(arg)
diff --git a/tests/unit/records/targets/test_table_file_objects.py b/tests/unit/records/targets/test_table_file_objects.py
index 1ae9a8ad5..ada5f2c73 100644
--- a/tests/unit/records/targets/test_table_file_objects.py
+++ b/tests/unit/records/targets/test_table_file_objects.py
@@ -37,4 +37,4 @@ def test_can_move_directly_from_scheme(self):
         mock_loader = mock_driver.loader.return_value
         mock_loader.best_scheme_to_load_from.return_value = mock_scheme
         self.assertEqual(True, self.table.can_move_directly_from_scheme(mock_scheme))
-        self.mock_db_driver.assert_called_with(self.mock_db_engine)
+        self.mock_db_driver.assert_called_with(None, db_engine=self.mock_db_engine, db_conn=None)
diff --git a/tests/unit/records/test_prep.py b/tests/unit/records/test_prep.py
index 5ff7621d8..16b403d0b 100644
--- a/tests/unit/records/test_prep.py
+++ b/tests/unit/records/test_prep.py
@@ -1,5 +1,5 @@
 import unittest
-from mock import Mock, MagicMock, patch, call
+from mock import Mock, MagicMock, patch
 from records_mover.records.prep import TablePrep
 from records_mover.records.existing_table_handling import ExistingTableHandling
 from records_mover.db import DBDriver
@@ -12,9 +12,10 @@ def setUp(self):

     @patch('records_mover.records.prep.quote_schema_and_table')
     def test_prep_table_exists_append_implicit(self, mock_quote_schema_and_table):
-        mock_schema_sql = Mock(name='schema_sql')
+        mock_schema_sql = 'mock_schema_sql'
         mock_driver = Mock(name='driver', spec=DBDriver)
-        mock_driver.db = Mock(name='db')
+        mock_driver.db_conn = MagicMock(name='db')
+        mock_driver.db_engine = MagicMock(name='db_engine')
         mock_driver.has_table.return_value = True
         how_to_prep = ExistingTableHandling.APPEND
@@ -24,9 +25,10 @@

     @patch('records_mover.records.prep.quote_schema_and_table')
     def test_prep_table_exists_truncate_implicit(self, mock_quote_schema_and_table):
-        mock_schema_sql = Mock(name='schema_sql')
+        mock_schema_sql = 'mock_schema_sql'
         mock_driver = Mock(name='driver', spec=DBDriver)
-        mock_driver.db = Mock(name='db')
+        mock_driver.db_conn = MagicMock(name='db')
+        mock_driver.db_engine = MagicMock(name='db_engine')
         mock_quote_schema_and_table
         mock_driver.has_table.return_value = True
@@ -36,17 +38,19 @@ def test_prep_table_exists_truncate_implicit(self, mock_quote_schema_and_table):
         self.prep.prep(mock_schema_sql, mock_driver)

-        mock_quote_schema_and_table.assert_called_with(mock_driver.db,
+        mock_quote_schema_and_table.assert_called_with(None,
                                                        self.mock_tbl.schema_name,
-                                                       self.mock_tbl.table_name)
-        str_arg = str(mock_driver.db.execute.call_args.args[0])
+                                                       self.mock_tbl.table_name,
+                                                       db_engine=mock_driver.db_engine)
+        str_arg = str(mock_driver.db_conn.execute.call_args.args[0])
         self.assertEqual(str_arg, f"TRUNCATE TABLE {mock_schema_and_table}")

     @patch('records_mover.records.prep.quote_schema_and_table')
     def test_prep_table_exists_delete_implicit(self, mock_quote_schema_and_table):
-        mock_schema_sql = Mock(name='schema_sql')
+        mock_schema_sql = 'mock_schema_sql'
         mock_driver = Mock(name='driver', spec=DBDriver)
-        mock_driver.db = Mock(name='db')
+        mock_driver.db_conn = MagicMock(name='db')
+        mock_driver.db_engine = MagicMock(name='db_engine')
         mock_quote_schema_and_table
         mock_driver.has_table.return_value = True
@@ -56,19 +60,21 @@
         self.prep.prep(mock_schema_sql, mock_driver)

-        mock_quote_schema_and_table.assert_called_with(mock_driver.db,
+        mock_quote_schema_and_table.assert_called_with(None,
                                                        self.mock_tbl.schema_name,
-                                                       self.mock_tbl.table_name)
-        str_arg = str(mock_driver.db.execute.call_args.args[0])
+                                                       self.mock_tbl.table_name,
+                                                       db_engine=mock_driver.db_engine)
+        str_arg = str(mock_driver.db_conn.execute.call_args.args[0])
         self.assertEqual(str_arg, f"DELETE FROM {mock_schema_and_table} WHERE true")

     @patch('records_mover.records.prep.quote_schema_and_table')
     def test_prep_table_exists_drop_implicit(self, mock_quote_schema_and_table):
-        mock_schema_sql = Mock(name='schema_sql')
+        mock_schema_sql = 'mock_schema_sql'
         mock_driver = Mock(name='driver', spec=DBDriver)
         mock_driver.db_engine = mock_driver
         mock_db = MagicMock(name='db')
-        mock_driver.db = mock_db
+        mock_driver.db_conn = mock_db
+        mock_driver.db_engine = MagicMock(name='db_engine')
         mock_quote_schema_and_table
         mock_driver.has_table.return_value = True
@@ -77,35 +83,37 @@
         self.mock_tbl.existing_table_handling = how_to_prep
         self.prep.prep(mock_schema_sql, mock_driver)
-        mock_conn = mock_db.engine.connect.return_value.__enter__.return_value
+        mock_conn = mock_driver.db_engine.connect.return_value.__enter__.return_value

-        mock_quote_schema_and_table.assert_called_with(mock_driver.db,
+        mock_quote_schema_and_table.assert_called_with(None,
                                                        self.mock_tbl.schema_name,
-                                                       self.mock_tbl.table_name)
-        mock_conn.execute.assert_has_calls([
-            call(f"DROP TABLE {mock_schema_and_table}"),
-        ])
-        mock_conn.exec_driver_sql.assert_has_calls([
-            call(mock_schema_sql),
-        ])
+                                                       self.mock_tbl.table_name,
+                                                       db_engine=mock_driver.db_engine)
+        str_args = [str(call_arg.args[0]) for call_arg in mock_conn.execute.call_args_list]
+        drop_table_str_arg, mock_schema_sql_str_arg = str_args[0], str_args[1]
+        self.assertEqual(drop_table_str_arg, f"DROP TABLE {mock_schema_and_table}")
+        self.assertEqual(mock_schema_sql_str_arg, mock_schema_sql)
         mock_driver.set_grant_permissions_for_groups.\
             assert_called_with(self.mock_tbl.schema_name,
                                self.mock_tbl.table_name,
                                self.mock_tbl.add_group_perms_for,
-                               mock_driver)
+                               None,
+                               db_conn=mock_conn)
         mock_driver.set_grant_permissions_for_users.\
             assert_called_with(self.mock_tbl.schema_name,
                                self.mock_tbl.table_name,
                                self.mock_tbl.add_user_perms_for,
-                               mock_driver)
+                               None,
+                               db_conn=mock_conn)

     @patch('records_mover.records.prep.quote_schema_and_table')
     def test_prep_table_not_exists(self, mock_quote_schema_and_table):
-        mock_schema_sql = Mock(name='schema_sql')
+        mock_schema_sql = 'mock_schema_sql'
         mock_driver = Mock(name='driver', spec=DBDriver)
         mock_driver.db_engine = mock_driver
         mock_db = MagicMock(name='db')
-        mock_driver.db = mock_db
+        mock_driver.db_conn = mock_db
+        mock_driver.db_engine = MagicMock(name='db_engine')
         mock_quote_schema_and_table
         mock_driver.has_table.return_value = False
@@ -113,29 +121,32 @@
         self.mock_tbl.existing_table_handling = how_to_prep
         self.prep.prep(mock_schema_sql, mock_driver)
-        mock_conn = mock_db.engine.connect.return_value.__enter__.return_value
-
-        mock_conn.exec_driver_sql.assert_has_calls([
-            call(mock_schema_sql),
-        ])
+        mock_conn = mock_driver.db_engine.connect.return_value.__enter__.return_value
+        print(mock_conn.execute)
+        str_args = [str(call_arg.args[0]) for call_arg in mock_conn.execute.call_args_list]
+        mock_schema_sql_str_arg = str_args[0]
+        self.assertEqual(mock_schema_sql_str_arg, mock_schema_sql)
         mock_driver.set_grant_permissions_for_groups.\
             assert_called_with(self.mock_tbl.schema_name,
                                self.mock_tbl.table_name,
                                self.mock_tbl.add_group_perms_for,
-                               mock_driver)
+                               None,
+                               db_conn=mock_conn)
         mock_driver.set_grant_permissions_for_users.\
             assert_called_with(self.mock_tbl.schema_name,
                                self.mock_tbl.table_name,
                                self.mock_tbl.add_user_perms_for,
-                               mock_driver)
+                               None,
+                               db_conn=mock_conn)

     @patch('records_mover.records.prep.quote_schema_and_table')
     def test_prep_table_exists_drop_explicit(self, mock_quote_schema_and_table):
-        mock_schema_sql = Mock(name='schema_sql')
+        mock_schema_sql = 'mock_schema_sql'
         mock_driver = Mock(name='driver', spec=DBDriver)
-        mock_driver.db_engine = mock_driver
         mock_db = MagicMock(name='db')
-        mock_driver.db = mock_db
+        mock_driver.db_conn = mock_db
+        mock_db_engine = MagicMock(name='db_engine')
+        mock_driver.db_engine = mock_db_engine
         mock_quote_schema_and_table
         mock_driver.has_table.return_value = True
@@ -145,24 +156,24 @@
         self.prep.prep(mock_schema_sql, mock_driver,
                        existing_table_handling=ExistingTableHandling.DROP_AND_RECREATE)
-        mock_conn = mock_db.engine.connect.return_value.__enter__.return_value
-
-        mock_quote_schema_and_table.assert_called_with(mock_driver.db,
+        mock_conn = mock_db_engine.connect.return_value.__enter__.return_value
+        mock_quote_schema_and_table.assert_called_with(None,
                                                        self.mock_tbl.schema_name,
-                                                       self.mock_tbl.table_name)
-        mock_conn.execute.assert_has_calls([
-            call(f"DROP TABLE {mock_schema_and_table}"),
-        ])
-        mock_conn.exec_driver_sql.assert_has_calls([
-            call(mock_schema_sql),
-        ])
+                                                       self.mock_tbl.table_name,
+                                                       db_engine=mock_driver.db_engine)
+        str_args = [str(call_arg.args[0]) for call_arg in mock_conn.execute.call_args_list]
+        drop_table_str_arg, mock_schema_sql_str_arg = str_args[0], str_args[1]
+        self.assertEqual(drop_table_str_arg, f"DROP TABLE {mock_schema_and_table}")
+        self.assertEqual(mock_schema_sql_str_arg, mock_schema_sql)
         mock_driver.set_grant_permissions_for_groups.\
             assert_called_with(self.mock_tbl.schema_name,
                                self.mock_tbl.table_name,
                                self.mock_tbl.add_group_perms_for,
-                               mock_driver)
+                               None,
+                               db_conn=mock_conn)
         mock_driver.set_grant_permissions_for_users.\
             assert_called_with(self.mock_tbl.schema_name,
                                self.mock_tbl.table_name,
                                self.mock_tbl.add_user_perms_for,
-                               mock_driver)
+                               None,
+                               db_conn=mock_conn)
diff --git a/tests/unit/test_records_hook.py b/tests/unit/test_records_hook.py
index 5339067e1..28b5a7e49 100644
--- a/tests/unit/test_records_hook.py
+++ b/tests/unit/test_records_hook.py
@@ -24,5 +24,5 @@ def test_get_conn(self,
         self.assertEqual(out, mock_Records.return_value)
         args, kwargs = mock_Records.call_args
         mock_db = Mock(name='db')
-        self.assertEqual(mock_db_driver.return_value, kwargs['db_driver'](mock_db))
+        self.assertEqual(mock_db_driver.return_value, kwargs['db_driver'](db_conn=mock_db))
         self.assertEqual(mock_UrlResolver.return_value, kwargs['url_resolver'])
diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py
index ba8204993..9e8ceb39a 100644
--- a/tests/unit/test_session.py
+++ b/tests/unit/test_session.py
@@ -61,13 +61,15 @@ def test_db_driver(self,
         mock_s3_temp_base_loc = mock_url_resolver.directory_url.return_value
         mock_gcs_temp_base_loc = mock_url_resolver.directory_url.return_value
         session = Session(creds=mock_creds)
-        out = session.db_driver(mock_db)
+        out = session.db_driver(None, db_conn=mock_db)
         self.assertEqual(out, mock_db_driver.return_value)
         mock_url_resolver.directory_url.assert_has_calls(
             [call(mock_scratch_s3_url),
              call(mock_scratch_gcs_url)]
         )
-        mock_db_driver.assert_called_with(db=mock_db,
+        mock_db_driver.assert_called_with(db=None,
+                                          db_conn=mock_db,
+                                          db_engine=None,
                                           url_resolver=mock_url_resolver,
                                           s3_temp_base_loc=mock_s3_temp_base_loc,
                                           gcs_temp_base_loc=mock_gcs_temp_base_loc)