Skip to content

Commit

Permalink
Merge pull request #250 from bluelabsio/RM-96-rewrite-DBDriver-so-cur…
Browse files Browse the repository at this point in the history
…rent-db-becomes-a-connection-db_engine-takes-all-Engine-responsibilities

RM-96-rewrite-DBDriver-so-current-db-becomes-a-connection-db_engine-takes-all-Engine-responsibilities
  • Loading branch information
ryantimjohn authored Aug 30, 2023
2 parents 8ef1ff3 + 4aea0c0 commit 6eb5217
Show file tree
Hide file tree
Showing 105 changed files with 1,113 additions and 654 deletions.
26 changes: 8 additions & 18 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ commands:
default: false
python_version:
type: string
description: "Version of python to test against"
description: "Version of python to test against."
pandas_version:
type: string
description: "Version of pandas to test against, or empty string for none"
Expand All @@ -34,7 +34,7 @@ commands:
default: ""
steps:
- restore_cache:
key: deps-v9-<<parameters.python_version>>-<<parameters.pandas_version>>-<<parameters.extras>>-<<parameters.include_dev_dependencies>>-{{ .Branch }}-{{ checksum "requirements.txt" }}-{{ checksum "setup.py" }}
key: deps-v10-<<parameters.python_version>>-<<parameters.pandas_version>>-<<parameters.extras>>-<<parameters.include_dev_dependencies>>-{{ .Branch }}-{{ checksum "requirements.txt" }}-{{ checksum "setup.py" }}
- run:
name: Install python deps in venv
environment:
Expand Down Expand Up @@ -73,7 +73,7 @@ commands:
fi
fi
- save_cache:
key: deps-v9-<<parameters.python_version>>-<<parameters.pandas_version>>-<<parameters.extras>>-<<parameters.include_dev_dependencies>>-{{ .Branch }}-{{ checksum "requirements.txt" }}-{{ checksum "setup.py" }}
key: deps-v10-<<parameters.python_version>>-<<parameters.pandas_version>>-<<parameters.extras>>-<<parameters.include_dev_dependencies>>-{{ .Branch }}-{{ checksum "requirements.txt" }}-{{ checksum "setup.py" }}
paths:
- "venv"
wait_for_db:
Expand Down Expand Up @@ -182,7 +182,7 @@ jobs:
- slack/notify:
event: fail
branch_pattern: main
channel: engineering-general
channel: records-mover-ci-builds
mentions: '<@bruno.castrokarney>'
template: basic_fail_1

Expand Down Expand Up @@ -467,15 +467,6 @@ workflows:
# We try to test against all non-end-of-life Python versions:
#
# https://devguide.python.org/devcycle/#end-of-life-branches
- test:
name: test-3.7
extras: '[unittest,typecheck]'
python_version: "3.7"
pandas_version: ">=1.3.5,<2"
context: slack-secrets
filters:
tags:
only: /v\d+\.\d+\.\d+(-[\w]+)?/
- test:
name: test-3.8
extras: '[unittest,typecheck]'
Expand Down Expand Up @@ -690,8 +681,8 @@ workflows:
export DB_FACTS_PATH=${PWD}/tests/integration/circleci-dbfacts.yml
export RECORDS_MOVER_SESSION_TYPE=env
mkdir -p test-reports/itest
cd tests/integration
python3 -m records.multi_db.test_records_table2table
cd tests/integration/records/multi_db
pytest -vvv test_records_table2table.py
- integration_test_with_dbs:
name: tbl2tbl-itest-old-sqlalchemy
extras: '[literally_every_single_database_binary,itest]'
Expand All @@ -704,8 +695,8 @@ workflows:
export DB_FACTS_PATH=${PWD}/tests/integration/circleci-dbfacts.yml
export RECORDS_MOVER_SESSION_TYPE=env
mkdir -p test-reports/itest
cd tests/integration
python3 -m records.multi_db.test_records_table2table
cd tests/integration/records/multi_db
pytest -vvv test_records_table2table.py
- integration_test_with_dbs:
name: redshift-test-dataframe-schema-sql-creation
extras: '[redshift-binary,itest]'
Expand Down Expand Up @@ -746,7 +737,6 @@ workflows:
- PyPI
- slack-secrets
requires:
- test-3.7
- test-3.8
- test-3.9
- cli-extra-test
Expand Down
2 changes: 1 addition & 1 deletion metrics/flake8_high_water_mark
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1
0
13 changes: 9 additions & 4 deletions records_mover/airflow/hooks/records_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
from records_mover.db import DBDriver
from records_mover.url.resolver import UrlResolver
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from typing import Optional, List, TYPE_CHECKING
from typing import Optional, Union, List, TYPE_CHECKING
import sqlalchemy
from ...check_db_conn_engine import check_db_conn_engine

try:
# Works with Airflow 1
Expand Down Expand Up @@ -50,11 +51,15 @@ def _url_resolver(self) -> UrlResolver:
gcs_client_getter=lambda: None,
gcp_credentials_getter=lambda: None)

def _db_driver(self, db: sqlalchemy.engine.Engine) -> DBDriver:
def _db_driver(self, db: Optional[Union[sqlalchemy.engine.Engine,
sqlalchemy.engine.Connection]] = None,
db_conn: Optional[sqlalchemy.engine.Connection] = None,
db_engine: Optional[sqlalchemy.engine.Engine] = None) -> DBDriver:
s3_temp_base_loc = (self._url_resolver.directory_url(self._s3_temp_base_url)
if self._s3_temp_base_url else None)

return db_driver(db=db, url_resolver=self._url_resolver, s3_temp_base_loc=s3_temp_base_loc)
db, db_conn, db_engine = check_db_conn_engine(db=db, db_conn=db_conn, db_engine=db_engine)
return db_driver(db=db, db_conn=db_conn, db_engine=db_engine,
url_resolver=self._url_resolver, s3_temp_base_loc=s3_temp_base_loc)

@property
def _s3_temp_base_url(self) -> Optional[str]:
Expand Down
31 changes: 31 additions & 0 deletions records_mover/check_db_conn_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# flake8: noqa

import sqlalchemy
from typing import Union, Optional, Tuple
import warnings


def check_db_conn_engine(db: "Optional[Union[sqlalchemy.engine.Engine, sqlalchemy.engine.Connection]]" = None,
                         db_conn: "Optional[sqlalchemy.engine.Connection]" = None,
                         db_engine: "Optional[sqlalchemy.engine.Engine]" = None) -> \
        "Tuple[Optional[Union[sqlalchemy.engine.Engine, sqlalchemy.engine.Connection]], Optional[sqlalchemy.engine.Connection], sqlalchemy.engine.Engine]":
    """Normalize the deprecated ``db`` argument into ``(db, db_conn, db_engine)``.

    ``db`` historically accepted either an Engine or a Connection; callers
    should now pass ``db_conn`` (Connection) and/or ``db_engine`` (Engine)
    explicitly.  This helper bridges the two APIs:

    * warns (``DeprecationWarning``) when ``db`` is supplied;
    * fills ``db_conn``/``db_engine`` from ``db`` based on its type;
    * derives ``db_engine`` from ``db_conn.engine`` when only a
      connection is available.

    :param db: deprecated — an Engine or Connection (or ``None``).
    :param db_conn: an explicit SQLAlchemy Connection, if the caller has one.
    :param db_engine: an explicit SQLAlchemy Engine, if the caller has one.
    :return: ``(db, db_conn, db_engine)`` with ``db_engine`` guaranteed
        to be populated when the inputs allow it.
    :raises ValueError: if none of the three arguments is provided.
    """
    if db:
        # NB: the string pieces below are adjacent-literal concatenation —
        # keep the trailing/leading spaces so words don't run together.
        warnings.warn("The db argument is deprecated and will be "
                      "removed in future releases.\n"
                      "Please use db_conn for Connection objects and "
                      "db_engine for Engine objects.",
                      DeprecationWarning,
                      stacklevel=2)  # point the warning at the caller, not here
    if not (db or db_conn or db_engine):
        raise ValueError("Either db, db_conn, or db_engine must be provided as arguments")
    if isinstance(db, sqlalchemy.engine.Connection) and not db_conn:
        db_conn = db
    if isinstance(db, sqlalchemy.engine.Engine) and not db_engine:
        db_engine = db
    if not db_engine:
        # No engine anywhere in the inputs: derive one from the connection.
        # (db_conn must be non-None here given the checks above, hence the
        # narrowing ignore for mypy.)
        db_engine = db_conn.engine  # type: ignore[union-attr]
    return (db, db_conn, db_engine)  # type: ignore[return-value]
2 changes: 1 addition & 1 deletion records_mover/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@
'create_sqlalchemy_url',
]

from .driver import DBDriver # noqa
from .driver import DBDriver # noqa
from .errors import LoadError # noqa
from .connect import create_sqlalchemy_url
16 changes: 11 additions & 5 deletions records_mover/db/bigquery/bigquery_db_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from ...records.records_format import BaseRecordsFormat, ParquetRecordsFormat, AvroRecordsFormat
from ...utils.limits import INT64_MAX, INT64_MIN, FLOAT64_SIGNIFICAND_BITS, num_digits
import re
from typing import Optional, Tuple
from typing import Union, Optional, Tuple
from ...url.resolver import UrlResolver
import sqlalchemy
from .loader import BigQueryLoader
Expand All @@ -19,17 +19,23 @@

class BigQueryDBDriver(DBDriver):
def __init__(self,
db: sqlalchemy.engine.Engine,
db: Optional[Union[sqlalchemy.engine.Connection, sqlalchemy.engine.Engine]],
url_resolver: UrlResolver,
gcs_temp_base_loc: Optional[BaseDirectoryUrl] = None,
db_conn: Optional[sqlalchemy.engine.Connection] = None,
db_engine: Optional[sqlalchemy.engine.Engine] = None,
**kwargs: object) -> None:
super().__init__(db)
super().__init__(db, db_conn, db_engine)
self._bigquery_loader =\
BigQueryLoader(db=self.db,
BigQueryLoader(db=db,
db_conn=db_conn,
db_engine=db_engine,
url_resolver=url_resolver,
gcs_temp_base_loc=gcs_temp_base_loc)
self._bigquery_unloader =\
BigQueryUnloader(db=self.db,
BigQueryUnloader(db=db,
db_conn=db_conn,
db_engine=db_engine,
url_resolver=url_resolver,
gcs_temp_base_loc=gcs_temp_base_loc)

Expand Down
37 changes: 31 additions & 6 deletions records_mover/db/bigquery/loader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from contextlib import contextmanager
from typing import List, IO, Tuple, Optional, Iterator
from typing import Union, List, IO, Tuple, Optional, Iterator
from ...url import BaseDirectoryUrl
from ...records.delimited import complain_on_unhandled_hints
import pprint
Expand All @@ -19,19 +19,41 @@
import logging
from ..loader import LoaderFromFileobj
from ..errors import NoTemporaryBucketConfiguration
from ...check_db_conn_engine import check_db_conn_engine

logger = logging.getLogger(__name__)


class BigQueryLoader(LoaderFromFileobj):
def __init__(self,
db: sqlalchemy.engine.Engine,
db: Optional[Union[sqlalchemy.engine.Connection, sqlalchemy.engine.Engine]],
url_resolver: UrlResolver,
gcs_temp_base_loc: Optional[BaseDirectoryUrl])\
-> None:
gcs_temp_base_loc: Optional[BaseDirectoryUrl],
db_conn: Optional[sqlalchemy.engine.Connection] = None,
db_engine: Optional[sqlalchemy.engine.Engine] = None) -> None:
db, db_conn, db_engine = check_db_conn_engine(db=db, db_conn=db_conn, db_engine=db_engine)
self.db = db
self._db_conn = db_conn
self.db_engine = db_engine
self.url_resolver = url_resolver
self.gcs_temp_base_loc = gcs_temp_base_loc
self.conn_opened_here = False

# Lazily open a connection from the engine if none was injected, and
# remember that this object opened it so del_db_conn() knows to close it.
def get_db_conn(self) -> sqlalchemy.engine.Connection:
    if self._db_conn is None:
        self._db_conn = self.db_engine.connect()
        self.conn_opened_here = True
        logger.debug(f"Opened connection to database within {self} because none was provided.")
    return self._db_conn

# Allow callers to inject (or clear) the connection this object uses.
# NOTE(review): does not reset conn_opened_here, so a connection opened
# earlier by get_db_conn could still trigger a close via del_db_conn
# after being swapped out — confirm this is intended.
def set_db_conn(self, db_conn: Optional[sqlalchemy.engine.Connection]) -> None:
    self._db_conn = db_conn

# Close the connection only if this object opened it itself.
# NOTE(review): reads self.db_conn — the property below — whose getter
# re-opens a connection when _db_conn is None; in that path this would
# open a fresh connection just to close it. Presumably harmless; verify.
def del_db_conn(self) -> None:
    if self.conn_opened_here:
        self.db_conn.close()

# Expose db_conn as a managed attribute: lazy getter, injectable setter,
# close-if-owned deleter.
db_conn = property(get_db_conn, set_db_conn, del_db_conn)

def best_scheme_to_load_from(self) -> str:
return 'gs'
Expand Down Expand Up @@ -77,7 +99,7 @@ def load_from_fileobj(self, schema: str, table: str,
logger.info("Loading from fileobj into BigQuery")
# https://googleapis.github.io/google-cloud-python/latest/bigquery/usage/tables.html#creating-a-table
connection: Connection =\
self.db.engine.raw_connection().connection
self.db_engine.raw_connection().connection
# https://google-cloud.readthedocs.io/en/latest/bigquery/generated/google.cloud.bigquery.client.Client.html
client: Client = connection._client
project_id, dataset_id = self._parse_bigquery_schema_name(schema)
Expand Down Expand Up @@ -122,7 +144,7 @@ def load(self,

logger.info("Loading from records directory into BigQuery")
# https://googleapis.github.io/google-cloud-python/latest/bigquery/usage/tables.html#creating-a-table
connection: Connection = self.db.engine.raw_connection().connection
connection: Connection = self.db_engine.raw_connection().connection
# https://google-cloud.readthedocs.io/en/latest/bigquery/generated/google.cloud.bigquery.client.Client.html
client: Client = connection._client
project_id, dataset_id = self._parse_bigquery_schema_name(schema)
Expand Down Expand Up @@ -187,3 +209,6 @@ def known_supported_records_formats_for_load(self) -> List[BaseRecordsFormat]:
ParquetRecordsFormat(),
AvroRecordsFormat()
]

def __del__(self) -> None:
    # Best-effort cleanup: close any connection this loader opened itself
    # (see conn_opened_here / del_db_conn above).
    # NOTE(review): __del__ is not guaranteed to run promptly — or at all
    # during interpreter shutdown — so callers should not rely on this for
    # deterministic connection release.
    self.del_db_conn()
16 changes: 9 additions & 7 deletions records_mover/db/bigquery/unloader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import sqlalchemy
from contextlib import contextmanager
from typing import List, Iterator, Optional, Tuple
from typing import List, Iterator, Optional, Union, Tuple
import logging
from google.cloud.bigquery.dbapi.connection import Connection
from google.cloud.bigquery.client import Client
Expand All @@ -12,20 +12,22 @@
from records_mover.records.unload_plan import RecordsUnloadPlan
from records_mover.records.records_directory import RecordsDirectory
from records_mover.db.errors import NoTemporaryBucketConfiguration
from ...check_db_conn_engine import check_db_conn_engine

logger = logging.getLogger(__name__)


class BigQueryUnloader(Unloader):
def __init__(self,
db: sqlalchemy.engine.Engine,
db: Optional[Union[sqlalchemy.engine.Connection, sqlalchemy.engine.Engine]],
url_resolver: UrlResolver,
gcs_temp_base_loc: Optional[BaseDirectoryUrl])\
-> None:
self.db = db
gcs_temp_base_loc: Optional[BaseDirectoryUrl],
db_conn: Optional[sqlalchemy.engine.Connection] = None,
db_engine: Optional[sqlalchemy.engine.Engine] = None) -> None:
db, db_conn, db_engine = check_db_conn_engine(db=db, db_conn=db_conn, db_engine=db_engine)
self.url_resolver = url_resolver
self.gcs_temp_base_loc = gcs_temp_base_loc
super().__init__(db=db)
super().__init__(db=db, db_conn=db_conn, db_engine=db_engine)

def can_unload_format(self, target_records_format: BaseRecordsFormat) -> bool:
if isinstance(target_records_format, AvroRecordsFormat):
Expand Down Expand Up @@ -93,7 +95,7 @@ def unload(self,
logger.info("Loading from records directory into BigQuery")
# https://googleapis.github.io/google-cloud-python/latest/bigquery/usage/tables.html#creating-a-table
connection: Connection =\
self.db.engine.raw_connection().connection
self.db_engine.raw_connection().connection
# https://google-cloud.readthedocs.io/en/latest/bigquery/generated/google.cloud.bigquery.client.Client.html
client: Client = connection._client
project_id, dataset_id = self._parse_bigquery_schema_name(schema)
Expand Down
Loading

0 comments on commit 6eb5217

Please sign in to comment.