Merge pull request #193 from bluelabsio/RM-34-upgrade-syntax-to-support-airflow-2-0

RM-34-upgrade-syntax-to-support-airflow-2-0
ryantimjohn authored Jan 20, 2023
2 parents b2ad63f + acea0ed commit 62a31a3
Showing 21 changed files with 94 additions and 102 deletions.
20 changes: 10 additions & 10 deletions .circleci/config.yml
@@ -604,16 +604,16 @@ workflows:
             tags:
               only: /v\d+\.\d+\.\d+(-[\w]+)?/
       - integration_test:
-          name: bigquery-no-gcs-itest
-          extras: '[bigquery,itest]'
-          python_version: "3.9"
-          db_name: bltoolsdevbq-bq_itest
-          include_gcs_scratch_bucket: false
-          requires:
-            - redshift-s3-itest
-          filters:
-            tags:
-              only: /v\d+\.\d+\.\d+(-[\w]+)?/
+          name: bigquery-no-gcs-itest
+          extras: '[bigquery,itest]'
+          python_version: "3.9"
+          db_name: bltoolsdevbq-bq_itest
+          requires:
+            - redshift-s3-itest
+          include_gcs_scratch_bucket: false
+          filters:
+            tags:
+              only: /v\d+\.\d+\.\d+(-[\w]+)?/
       - integration_test:
           name: bigquery-gcs-itest
           extras: '[bigquery,itest]'
2 changes: 1 addition & 1 deletion metrics/bigfiles_high_water_mark
@@ -1 +1 @@
-1138
+1131
2 changes: 1 addition & 1 deletion metrics/coverage_high_water_mark
@@ -1 +1 @@
-93.0000
+93.6400
2 changes: 1 addition & 1 deletion metrics/mypy_high_water_mark
@@ -1 +1 @@
-92.2900
+92.3400
6 changes: 3 additions & 3 deletions records_mover/airflow/hooks/google_cloud_credentials_hook.py
@@ -1,15 +1,15 @@
-from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook
+from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
 from typing import Iterable, Optional, TYPE_CHECKING
 if TYPE_CHECKING:
     # see the 'gsheets' extras_require option in setup.py - needed for this!
     import google.auth.credentials  # noqa


-class GoogleCloudCredentialsHook(GoogleCloudBaseHook):
+class GoogleCloudCredentialsHook(GoogleBaseHook):
     def get_conn(self) -> 'google.auth.credentials.Credentials':
         return self._get_credentials()

-    def scopes(self) -> Iterable[str]:
+    def scopes(self) -> Iterable[str]:  # type: ignore
         scope: Optional[str] = self._get_field('scope', None)
         scopes: Iterable[str]
         if scope is not None:
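For context, a hedged usage sketch of the migrated hook under Airflow 2 (the connection id below is the Google provider's usual default, assumed here rather than taken from this commit):

    from records_mover.airflow.hooks.google_cloud_credentials_hook import (
        GoogleCloudCredentialsHook,
    )

    # GoogleBaseHook reads credentials and the 'scope' field from the
    # Airflow connection; get_conn() surfaces the raw credentials object.
    hook = GoogleCloudCredentialsHook(gcp_conn_id='google_cloud_default')
    creds = hook.get_conn()  # a google.auth.credentials.Credentials instance
    print(list(hook.scopes()))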
6 changes: 3 additions & 3 deletions records_mover/airflow/hooks/records_hook.py
@@ -5,7 +5,7 @@
 from records_mover.db.factory import db_driver
 from records_mover.db import DBDriver
 from records_mover.url.resolver import UrlResolver
-from airflow.contrib.hooks.aws_hook import AwsHook
+from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from typing import Optional, Union, List, TYPE_CHECKING
 import sqlalchemy

@@ -14,7 +14,7 @@
     from airflow.hooks import BaseHook
 except ImportError:
     # Required for Airflow 2.0
-    from airflow.hooks.base_hook import BaseHook  # type: ignore
+    from airflow.hooks.base import BaseHook  # type: ignore

 if TYPE_CHECKING:
     from boto3.session import ListObjectsResponseContentType, S3ClientTypeStub  # noqa
@@ -41,7 +41,7 @@ def __init__(self,

     def _get_boto3_session(self) -> boto3.session.Session:
         if not self._boto3_session:
-            self._boto3_session = AwsHook(self.aws_conn_id).get_session()
+            self._boto3_session = AwsBaseHook(self.aws_conn_id).get_session()
         return self._boto3_session

     @property
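A hedged sketch of the new session path (the connection id 'aws_default' is an assumption for illustration):

    from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook

    # In Airflow 2, AwsHook from airflow.contrib becomes AwsBaseHook in the
    # amazon provider package; get_session() still returns a boto3 session.
    session = AwsBaseHook(aws_conn_id='aws_default').get_session()
    s3 = session.client('s3')  # usable like any boto3 session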
8 changes: 1 addition & 7 deletions records_mover/airflow/hooks/sqlalchemy_db_hook.py
@@ -1,12 +1,6 @@
 import sqlalchemy as sa
 from records_mover.db import create_sqlalchemy_url
-
-try:
-    # Works with Airflow 1
-    from airflow.hooks import BaseHook
-except ImportError:
-    # Required for Airflow 2.0
-    from airflow.hooks.base_hook import BaseHook  # type: ignore
+from airflow.hooks.base import BaseHook


 class SqlAlchemyDbHook(BaseHook):
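The deleted try/except was the usual cross-version import shim; with Airflow 1 support dropped it collapses to a direct import. For anyone still straddling both major versions, a hedged sketch of that shim:

    try:
        # Airflow 1.x location
        from airflow.hooks import BaseHook
    except ImportError:
        # Airflow 2.x location
        from airflow.hooks.base import BaseHook  # type: ignore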
6 changes: 3 additions & 3 deletions records_mover/creds/creds_via_airflow.py
@@ -14,12 +14,12 @@

 class CredsViaAirflow(BaseCreds):
     def boto3_session(self, aws_creds_name: str) -> 'boto3.session.Session':
-        from airflow.contrib.hooks.aws_hook import AwsHook
-        aws_hook = AwsHook(aws_creds_name)
+        from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+        aws_hook = AwsBaseHook(aws_creds_name)
         return aws_hook.get_session()

     def db_facts(self, db_creds_name: str) -> DBFacts:
-        from airflow.hooks import BaseHook
+        from airflow.hooks.base import BaseHook
         conn = BaseHook.get_connection(db_creds_name)
         out: DBFacts = {}
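For context, a hedged sketch of how an Airflow Connection is commonly unpacked into a plain dict of database facts (the connection id and key names are illustrative, not this repo's DBFacts schema):

    from airflow.hooks.base import BaseHook

    conn = BaseHook.get_connection('my_db')  # 'my_db' is a hypothetical conn id
    db_facts = {
        'host': conn.host,
        'port': conn.port,
        'user': conn.login,
        'password': conn.password,
        'database': conn.schema,
    }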
3 changes: 2 additions & 1 deletion setup.cfg
@@ -31,6 +31,7 @@ max-complexity = 15
 [mypy]
 mypy_path = types/stubs
 warn_unused_ignores = True
+disable_error_code = annotation-unchecked

 [mypy-alembic.*]
 ignore_missing_imports = True
@@ -84,4 +85,4 @@ ignore_missing_imports = True
 ignore_missing_imports = True

 [mypy-airflow.hooks.*]
-ignore_missing_imports = True
\ No newline at end of file
+ignore_missing_imports = True
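A hedged note on the new mypy flag: disable_error_code = annotation-unchecked suppresses mypy's informational notes about annotated variables inside untyped function bodies (a note code introduced around mypy 0.990; treat that version detail as an assumption). A minimal illustration:

    # Without the flag, mypy emits an annotation-unchecked note here,
    # because bodies of untyped functions are not type-checked by default.
    def legacy_helper(x):
        count: int = x
        return count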
15 changes: 4 additions & 11 deletions setup.py
@@ -137,20 +137,13 @@ def initialize_options(self) -> None:
 )

 airflow_dependencies = [
-    # Minimum version here is needed to avoid syntax error in setup.py
-    # in 1.10.0
-    'apache-airflow>=1.10.1,<2'
+    'apache-airflow>=2',
+    'apache-airflow-providers-amazon',
+    'apache-airflow-providers-google',
 ]

 db_dependencies = [
-    # Lower bound (>=1.3.18) is to improve package resolution performance
-    #
-    # Upper bound (<1.4) is to avoid 1.4 which has breaking changes and is
-    # incompatible with python-bigquery-sqlalchemy per
-    # https://github.com/googleapis/python-bigquery-sqlalchemy/issues/83
-    # Can lift this once records-mover itself is compatible and
-    # other packages have appropriate restrictions in place.
-    'sqlalchemy>=1.3.18,<1.4',
+    'sqlalchemy>=1.3.18',
 ]

 smart_open_dependencies = [
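For context, a hedged sketch of verifying the new Airflow 2 dependency set after installation (the distribution names come from the list above; the check itself is illustrative):

    import importlib.metadata as md

    # Under Airflow 2 the AWS and Google hooks ship in separate
    # provider distributions, hence the two extra requirements.
    for dist in ('apache-airflow',
                 'apache-airflow-providers-amazon',
                 'apache-airflow-providers-google'):
        print(dist, md.version(dist))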
62 changes: 34 additions & 28 deletions tests/component/records/schema/field/test_dtype.py
@@ -6,6 +6,7 @@
 )
 import numpy as np
 import pandas as pd
+import pytest


 def with_nullable(nullable: bool, fn):
@@ -31,29 +32,30 @@ def check_dtype(field_type, constraints, expectation):
     assert out.dtype == expectation


-def test_to_pandas_dtype_integer_no_nullable():
+class Test_to_pandas_dtype_integer_no_nullable:
     expectations = {
-        (-100, 100): np.int8,
-        (0, 240): np.uint8,
-        (-10000, 10000): np.int16,
-        (500, 40000): np.uint16,
-        (-200000000, 200000000): np.int32,
-        (25, 4000000000): np.uint32,
-        (-9000000000000000000, 2000000000): np.int64,
-        (25, 10000000000000000000): np.uint64,
+        (-100, 100): pd.Int8Dtype(),
+        (0, 240): pd.UInt8Dtype(),
+        (-10000, 10000): pd.Int16Dtype(),
+        (500, 40000): pd.UInt16Dtype(),
+        (-200000000, 200000000): pd.Int32Dtype(),
+        (25, 4000000000): pd.UInt32Dtype(),
+        (-9000000000000000000, 2000000000): pd.Int64Dtype(),
+        (25, 10000000000000000000): pd.UInt64Dtype(),
         (25, 1000000000000000000000000000): np.float128,
-        (None, None): np.int64,
+        (None, None): pd.Int64Dtype(),
     }
-    for (min_, max_), expected_pandas_type in expectations.items():
+
+    @pytest.mark.parametrize("expectation", expectations.items())
+    def test_to_pandas_dtype_integer_no_nullable(self, expectation):
+        (min_, max_), expected_pandas_type = expectation
         constraints = RecordsSchemaFieldIntegerConstraints(
             required=True, unique=None, min_=min_, max_=max_
         )
-        yield with_nullable(
-            False, check_dtype
-        ), "integer", constraints, expected_pandas_type
+        with_nullable(False, check_dtype("integer", constraints, expected_pandas_type))


-def test_to_pandas_dtype_integer_nullable():
+class Test_to_pandas_dtype_integer_nullable:
     expectations = {
         (-100, 100): pd.Int8Dtype(),
         (0, 240): pd.UInt8Dtype(),
@@ -66,16 +68,17 @@ def test_to_pandas_dtype_integer_nullable():
         (25, 1000000000000000000000000000): np.float128,
         (None, None): pd.Int64Dtype(),
     }
-    for (min_, max_), expected_pandas_type in expectations.items():
+
+    @pytest.mark.parametrize("expectation", expectations.items())
+    def test_to_pandas_dtype_integer_nullable(self, expectation):
+        (min_, max_), expected_pandas_type = expectation
         constraints = RecordsSchemaFieldIntegerConstraints(
             required=True, unique=None, min_=min_, max_=max_
         )
-        yield with_nullable(
-            True, check_dtype
-        ), "integer", constraints, expected_pandas_type
+        with_nullable(True, check_dtype("integer", constraints, expected_pandas_type))


-def test_to_pandas_dtype_decimal_float():
+class Test_to_pandas_dtype_decimal_float():
     expectations = {
         (8, 4): np.float16,
         (20, 10): np.float32,
@@ -84,10 +87,10 @@ def test_to_pandas_dtype_decimal_float():
         (500, 250): np.float128,
         (None, None): np.float64,
     }
-    for (
-        fp_total_bits,
-        fp_significand_bits,
-    ), expected_pandas_type in expectations.items():
+
+    @pytest.mark.parametrize("expectation", expectations.items())
+    def test_to_pandas_dtype_Tdecimal_float(self, expectation):
+        (fp_total_bits, fp_significand_bits), expected_pandas_type = expectation
         constraints = RecordsSchemaFieldDecimalConstraints(
             required=False,
             unique=None,
@@ -96,10 +99,10 @@
             fp_total_bits=fp_total_bits,
             fp_significand_bits=fp_significand_bits,
         )
-        yield check_dtype, "decimal", constraints, expected_pandas_type
+        check_dtype("decimal", constraints, expected_pandas_type)


-def test_to_pandas_dtype_misc():
+class Test_to_pandas_dtype_misc():
     expectations = {
         "boolean": np.bool_,
         "string": np.object_,
@@ -108,8 +111,11 @@
         "datetimetz": "datetime64[ns, UTC]",
         "time": np.object_,
     }
-    for field_type, expected_pandas_type in expectations.items():
-        yield check_dtype, field_type, None, expected_pandas_type
+
+    @pytest.mark.parametrize("expectation", expectations.items())
+    def test_to_pandas_dtype_misc(self, expectation):
+        field_type, expected_pandas_type = expectation
+        check_dtype(field_type, None, expected_pandas_type)


 def test_to_pandas_dtype_fixed_precision_():
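This file's conversion replaces nose-style yield tests, which pytest no longer collects, with parametrized test classes. A minimal self-contained sketch of the same idiom (all names are illustrative stand-ins, not this repo's API):

    import pytest


    class TestPickDtype:
        expectations = {
            (-100, 100): 'int8',
            (0, 240): 'uint8',
        }

        @pytest.mark.parametrize("case", expectations.items())
        def test_pick_dtype(self, case):
            (min_, max_), expected = case

            def pick_dtype(lo, hi):  # stand-in for the code under test
                return 'uint8' if lo >= 0 else 'int8'

            assert pick_dtype(min_, max_) == expected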
12 changes: 5 additions & 7 deletions tests/integration/records/expected_column_types.py
@@ -9,12 +9,11 @@
     'redshift': [
         'INTEGER', 'VARCHAR(3)', 'VARCHAR(3)', 'VARCHAR(1)', 'VARCHAR(1)',
         'VARCHAR(3)', 'VARCHAR(111)', 'DATE', 'VARCHAR(8)',
-        'TIMESTAMP WITHOUT TIME ZONE', 'TIMESTAMPTZ'
+        'TIMESTAMP', 'TIMESTAMPTZ'
     ],
     'postgresql': [
         'INTEGER', 'VARCHAR(3)', 'VARCHAR(3)', 'VARCHAR(1)', 'VARCHAR(1)',
-        'VARCHAR(3)', 'VARCHAR(111)', 'DATE', 'TIME WITHOUT TIME ZONE',
-        'TIMESTAMP WITHOUT TIME ZONE', 'TIMESTAMP WITH TIME ZONE'
+        'VARCHAR(3)', 'VARCHAR(111)', 'DATE', 'TIME', 'TIMESTAMP', 'TIMESTAMP'
     ],
     'bigquery': [
         'INTEGER', 'VARCHAR(3)', 'VARCHAR(3)', 'VARCHAR(1)', 'VARCHAR(1)', 'VARCHAR(3)',
@@ -28,9 +27,8 @@

 expected_df_loaded_database_column_types = {
     'postgresql': [
-        'BIGINT', 'VARCHAR(12)', 'VARCHAR(12)', 'VARCHAR(4)', 'VARCHAR(4)',
-        'VARCHAR(12)', 'VARCHAR(444)', 'DATE', 'TIME WITHOUT TIME ZONE',
-        'TIMESTAMP WITHOUT TIME ZONE', 'TIMESTAMP WITH TIME ZONE'
+        'BIGINT', 'VARCHAR(12)', 'VARCHAR(12)', 'VARCHAR(4)', 'VARCHAR(4)', 'VARCHAR(12)',
+        'VARCHAR(444)', 'DATE', 'TIME', 'TIMESTAMP', 'TIMESTAMP'
     ],
     'mysql': [
         'BIGINT(20)', 'VARCHAR(3)', 'VARCHAR(3)', 'VARCHAR(1)', 'VARCHAR(1)', 'VARCHAR(3)',
@@ -44,7 +42,7 @@
     'redshift': [
         'BIGINT', 'VARCHAR(12)', 'VARCHAR(12)', 'VARCHAR(4)', 'VARCHAR(4)',
         'VARCHAR(12)', 'VARCHAR(444)', 'DATE', 'VARCHAR(8)',
-        'TIMESTAMP WITHOUT TIME ZONE', 'TIMESTAMPTZ'
+        'TIMESTAMP', 'TIMESTAMPTZ'
     ],
     'bigquery': [
         'INTEGER', 'VARCHAR(12)', 'VARCHAR(12)', 'VARCHAR(4)', 'VARCHAR(4)',
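A hedged sketch of how stringified column types like those above are typically gathered for comparison (the engine URL and table name are placeholders); the shift from 'TIMESTAMP WITHOUT TIME ZONE' to 'TIMESTAMP' plausibly tracks the SQLAlchemy un-pinning in setup.py, though the commit does not say so:

    import sqlalchemy as sa

    engine = sa.create_engine('postgresql://localhost/mydb')  # placeholder URL
    inspector = sa.inspect(engine)
    actual = [str(col['type']) for col in inspector.get_columns('mytable')]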
10 changes: 5 additions & 5 deletions tests/integration/records/single_db/numeric_expectations.py
@@ -120,11 +120,11 @@
         'uint64': 'NUMERIC(20, 0)',
         'float16': 'REAL',
         'float32': 'REAL',
-        'float64': 'DOUBLE PRECISION',
-        'float128': 'DOUBLE PRECISION',  # Redshift doesn't support >float64
+        'float64': 'DOUBLE_PRECISION',
+        'float128': 'DOUBLE_PRECISION',  # Redshift doesn't support >float64
         'fixed_6_2': 'NUMERIC(6, 2)',
         'fixed_38_9': 'NUMERIC(38, 9)',
-        'fixed_100_4': 'DOUBLE PRECISION'  # Redshift doesn't support fixed precision > 38
+        'fixed_100_4': 'DOUBLE_PRECISION'  # Redshift doesn't support fixed precision > 38
     },
     'vertica': {
         'int8': 'INTEGER',
@@ -180,8 +180,8 @@
         'uint64': 'NUMERIC(20, 0)',
         'float16': 'REAL',
         'float32': 'REAL',
-        'float64': 'DOUBLE PRECISION',
-        'float128': 'DOUBLE PRECISION',  # Postgres doesn't support >float64
+        'float64': 'DOUBLE_PRECISION',
+        'float128': 'DOUBLE_PRECISION',  # Postgres doesn't support >float64
         'fixed_6_2': 'NUMERIC(6, 2)',
         'fixed_38_9': 'NUMERIC(38, 9)',
         'fixed_100_4': 'NUMERIC(100, 4)',
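A hedged illustration of why the expected spelling changed: the same PostgreSQL/Redshift double type stringifies either way depending on whether the class name or the compiled DDL is rendered (this is an inference, not stated in the commit):

    from sqlalchemy.dialects import postgresql
    from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION

    t = DOUBLE_PRECISION()
    print(type(t).__name__)                        # 'DOUBLE_PRECISION' (class name)
    print(t.compile(dialect=postgresql.dialect())) # 'DOUBLE PRECISION' (DDL form)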
8 changes: 4 additions & 4 deletions tests/unit/airflow/hooks/test_records_hook.py
@@ -12,9 +12,9 @@ def setUp(self):

     @patch('records_mover.airflow.hooks.records_hook.UrlResolver')
     @patch('records_mover.airflow.hooks.records_hook.Records')
-    @patch('records_mover.airflow.hooks.records_hook.AwsHook')
+    @patch('records_mover.airflow.hooks.records_hook.AwsBaseHook')
     def test_get_conn(self,
-                      mock_AwsHook,
+                      mock_AwsBaseHook,
                       mock_Records,
                       mock_UrlResolver):
         conn = self.records_hook.get_conn()
@@ -24,11 +24,11 @@

     @patch('records_mover.airflow.hooks.records_hook.UrlResolver')
     @patch('records_mover.airflow.hooks.records_hook.Records')
-    @patch('records_mover.airflow.hooks.records_hook.AwsHook')
+    @patch('records_mover.airflow.hooks.records_hook.AwsBaseHook')
     @patch('records_mover.airflow.hooks.records_hook.db_driver')
     def test_get_conn_invalid_s3_url(self,
                                      mock_db_driver,
-                                     mock_AwsHook,
+                                     mock_AwsBaseHook,
                                      mock_Records,
                                      mock_UrlResolver):
         records_hook = RecordsHook(s3_temp_base_url='foo',
4 changes: 2 additions & 2 deletions tests/unit/airflow/test_google_cloud_credentials_hook.py
@@ -1,4 +1,4 @@
-from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook
+from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
 from records_mover.airflow.hooks.google_cloud_credentials_hook import GoogleCloudCredentialsHook
 from mock import Mock
 import unittest
@@ -7,7 +7,7 @@
 class TestGoogleCloudCredentialsHook(unittest.TestCase):
     def test_get_conn(self):
         mock_init = Mock('__init__')
-        GoogleCloudBaseHook.__init__ = mock_init
+        GoogleBaseHook.__init__ = mock_init
         mock_init.return_value = None
         hook = GoogleCloudCredentialsHook()
         mock_get_credentials = Mock('get_credentials')
6 changes: 3 additions & 3 deletions tests/unit/airflow/test_records_hook.py
@@ -4,11 +4,11 @@


 class TestRecordsHook(unittest.TestCase):
-    @patch('records_mover.airflow.hooks.records_hook.AwsHook')
+    @patch('records_mover.airflow.hooks.records_hook.AwsBaseHook')
     def test_validate_and_prepare_target_directory(self,
-                                                   mock_AwsHook):
+                                                   mock_AwsBaseHook):
         target_url = 's3://bluelabs-fake-bucket'
-        mock_boto3_session = mock_AwsHook.return_value.get_session.return_value
+        mock_boto3_session = mock_AwsBaseHook.return_value.get_session.return_value
         mock_s3 = mock_boto3_session.client.return_value
         mock_s3.list_objects_v2.return_value.get.return_value =\
             [{
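The renames above follow the standard mocking rule: patch a name where it is looked up, not where it is defined. A hedged minimal sketch of the idiom:

    from unittest.mock import patch

    # records_hook imports AwsBaseHook into its own namespace, so the
    # patch target is the records_hook module, not the provider module.
    @patch('records_mover.airflow.hooks.records_hook.AwsBaseHook')
    def test_session(mock_AwsBaseHook):
        mock_session = mock_AwsBaseHook.return_value.get_session.return_value
        assert mock_session is not None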