diff --git a/metrics/bigfiles_high_water_mark b/metrics/bigfiles_high_water_mark
index 1f3d8a7a1..d0378c4c1 100644
--- a/metrics/bigfiles_high_water_mark
+++ b/metrics/bigfiles_high_water_mark
@@ -1 +1 @@
-1028
+1049
diff --git a/metrics/coverage_high_water_mark b/metrics/coverage_high_water_mark
index 9bb1e3ff0..5ca971149 100644
--- a/metrics/coverage_high_water_mark
+++ b/metrics/coverage_high_water_mark
@@ -1 +1 @@
-93.4700
\ No newline at end of file
+93.5400
\ No newline at end of file
diff --git a/metrics/flake8_high_water_mark b/metrics/flake8_high_water_mark
index fb402ef6a..de8febe1c 100644
--- a/metrics/flake8_high_water_mark
+++ b/metrics/flake8_high_water_mark
@@ -1 +1 @@
-169
+168
diff --git a/metrics/mypy_high_water_mark b/metrics/mypy_high_water_mark
index 6993dc564..d08d255b7 100644
--- a/metrics/mypy_high_water_mark
+++ b/metrics/mypy_high_water_mark
@@ -1 +1 @@
-92.3100
\ No newline at end of file
+92.3300
\ No newline at end of file
diff --git a/records_mover/mover_types.py b/records_mover/mover_types.py
index a94136bae..983254b7b 100644
--- a/records_mover/mover_types.py
+++ b/records_mover/mover_types.py
@@ -11,8 +11,10 @@
 # mypy way of validating we're covering all cases of an enum
 #
 # https://github.com/python/mypy/issues/6366#issuecomment-560369716
-def _assert_never(x: NoReturn) -> NoReturn:
-    assert False, "Unhandled type: {}".format(type(x).__name__)
+def _assert_never(x: NoReturn, errmsg: Optional[str] = None) -> NoReturn:
+    if errmsg is None:
+        errmsg = "Unhandled type: {}".format(type(x).__name__)
+    assert False, errmsg


 # mypy-friendly way of doing a singleton object:
diff --git a/records_mover/records/schema/field/__init__.py b/records_mover/records/schema/field/__init__.py
index bb5348d91..297cd18d6 100644
--- a/records_mover/records/schema/field/__init__.py
+++ b/records_mover/records/schema/field/__init__.py
@@ -19,13 +19,13 @@
                           RecordsSchemaFieldIntegerConstraints,
                           RecordsSchemaFieldDecimalConstraints)
 from .statistics import RecordsSchemaFieldStatistics
-from .types import RECORDS_FIELD_TYPES
+from .field_types import RECORDS_FIELD_TYPES

 if TYPE_CHECKING:
     from pandas import Series, Index
     from sqlalchemy import Column
     from sqlalchemy.types import TypeEngine
     from records_mover.db import DBDriver  # noqa
-    from .types import FieldType
+    from .field_types import FieldType
     from mypy_extensions import TypedDict
@@ -64,9 +64,9 @@ def __init__(self,
     def refine_from_series(self,
                            series: 'Series',
                            total_rows: int,
-                           rows_sampled: int) -> None:
+                           rows_sampled: int) -> 'RecordsSchemaField':
         from .pandas import refine_field_from_series
-        refine_field_from_series(self, series, total_rows, rows_sampled)
+        return refine_field_from_series(self, series, total_rows, rows_sampled)

     @staticmethod
     def is_more_specific_type(a: 'FieldType', b: 'FieldType') -> bool:
@@ -77,6 +77,7 @@ def is_more_specific_type(a: 'FieldType', b: 'FieldType') -> bool:
     @staticmethod
     def python_type_to_field_type(specific_type: Type[Any]) -> Optional['FieldType']:
         import numpy as np
+        import pandas as pd

         # Note: records spec doesn't cover complex number types, so
         # np.complex_, complex64 and complex128 are not supported
@@ -114,6 +115,10 @@ def python_type_to_field_type(specific_type: Type[Any]) -> Optional['FieldType']
             datetime.date: 'date',

             datetime.time: 'time',
+
+            datetime.datetime: 'datetime',
+
+            pd.Timestamp: 'datetime',
         }
         if specific_type not in type_mapping:
             logger.warning(f"Please teach me how to map {specific_type} into records "
@@ -318,3 +323,19 @@ def convert_datetime_to_datetimetz(self) -> 'RecordsSchemaField':
                                  constraints=self.constraints,
                                  statistics=self.statistics,
                                  representations=self.representations)
+
+    def cast(self, field_type: 'FieldType') -> 'RecordsSchemaField':
+        if self.constraints is None:
+            constraints = None
+        else:
+            constraints = self.constraints.cast(field_type)
+        if self.statistics is None:
+            statistics = None
+        else:
+            statistics = self.statistics.cast(field_type)
+        field = RecordsSchemaField(name=self.name,
+                                   field_type=field_type,
+                                   constraints=constraints,
+                                   statistics=statistics,
+                                   representations=self.representations)
+        return field
diff --git a/records_mover/records/schema/field/constraints/constraints.py b/records_mover/records/schema/field/constraints/constraints.py
index 92cb6eda0..3ded475dc 100644
--- a/records_mover/records/schema/field/constraints/constraints.py
+++ b/records_mover/records/schema/field/constraints/constraints.py
@@ -1,5 +1,5 @@
 import logging
-
+from records_mover.mover_types import _assert_never
 from typing import Optional, cast, TYPE_CHECKING
 from records_mover.utils.limits import (FLOAT16_SIGNIFICAND_BITS,
                                         FLOAT32_SIGNIFICAND_BITS,
@@ -31,14 +31,14 @@ class FieldIntegerConstraintsDict(FieldConstraintsDict, total=False):
         min: str
         max: str

-    from ..types import FieldType  # noqa
+    from ..field_types import FieldType  # noqa

 logger = logging.getLogger(__name__)


 class RecordsSchemaFieldConstraints:
-    def __init__(self, required: bool, unique: Optional[bool]=None):
+    def __init__(self, required: bool, unique: Optional[bool] = None):
         """
         :param required: If True, data must always be provided for this
         column in the origin representation; if False, a 'null' or
@@ -108,6 +108,45 @@ def from_data(data: Optional['FieldConstraintsDict'],
         else:
             return RecordsSchemaFieldConstraints(required=required,
                                                  unique=unique)

+    def cast(self, field_type: 'FieldType') -> 'RecordsSchemaFieldConstraints':
+        from .integer import RecordsSchemaFieldIntegerConstraints
+        from .decimal import RecordsSchemaFieldDecimalConstraints
+        from .string import RecordsSchemaFieldStringConstraints
+        required = self.required
+        unique = self.unique
+        constraints: RecordsSchemaFieldConstraints
+        if field_type == 'integer':
+            constraints =\
+                RecordsSchemaFieldIntegerConstraints(required=required,
+                                                     unique=unique,
+                                                     min_=None,
+                                                     max_=None)
+        elif field_type == 'string':
+            constraints =\
+                RecordsSchemaFieldStringConstraints(required=required,
+                                                    unique=unique,
+                                                    max_length_bytes=None,
+                                                    max_length_chars=None)
+        elif field_type == 'decimal':
+            constraints =\
+                RecordsSchemaFieldDecimalConstraints(required=required,
+                                                     unique=unique)
+        elif (field_type == 'boolean' or
+              field_type == 'date' or
+              field_type == 'time' or
+              field_type == 'timetz' or
+              field_type == 'datetime' or
+              field_type == 'datetimetz'):
+            constraints =\
+                RecordsSchemaFieldConstraints(required=required,
+                                              unique=unique)
+        else:
+            _assert_never(field_type,
+                          'Teach me how to downcast constraints '
+                          f'for {field_type}')
+
+        return constraints
+
     @staticmethod
     def from_sqlalchemy_type(required: bool,
                              unique: Optional[bool],
diff --git a/records_mover/records/schema/field/field_types.py b/records_mover/records/schema/field/field_types.py
new file mode 100644
index 000000000..43ef45b8d
--- /dev/null
+++ b/records_mover/records/schema/field/field_types.py
@@ -0,0 +1,16 @@
+from typing_extensions import Literal
+from typing_inspect import get_args
+from typing import List
+
+FieldType = Literal['integer',
+                    'decimal',
+                    'string',
+                    'boolean',
+                    'date',
+                    'time',
+                    'timetz',
+                    'datetime',
+                    'datetimetz']
+
+# Be sure to add new things below in FieldType, too
+RECORDS_FIELD_TYPES: List[str] = list(get_args(FieldType))  # type: ignore
diff --git a/records_mover/records/schema/field/numpy.py b/records_mover/records/schema/field/numpy.py
index eba8a193a..6f64ae914 100644
--- a/records_mover/records/schema/field/numpy.py
+++ b/records_mover/records/schema/field/numpy.py
@@ -2,7 +2,7 @@
 from typing import Optional, Tuple, TYPE_CHECKING
 from .constraints import RecordsSchemaFieldConstraints
 if TYPE_CHECKING:
-    from .types import FieldType  # noqa
+    from .field_types import FieldType  # noqa


 def details_from_numpy_dtype(dtype: numpy.dtype,
diff --git a/records_mover/records/schema/field/pandas.py b/records_mover/records/schema/field/pandas.py
index 770dfac15..5778ddfb1 100644
--- a/records_mover/records/schema/field/pandas.py
+++ b/records_mover/records/schema/field/pandas.py
@@ -44,7 +44,7 @@ def field_from_series(series: Series,
 def refine_field_from_series(field: 'RecordsSchemaField',
                              series: Series,
                              total_rows: int,
-                             rows_sampled: int) -> None:
+                             rows_sampled: int) -> 'RecordsSchemaField':
     from ..field import RecordsSchemaField  # noqa
     #
     # if the series is full of object types that aren't numpy
@@ -57,7 +57,7 @@
         field_type = field.python_type_to_field_type(unique_python_type)
         if field_type is not None:
             if RecordsSchemaField.is_more_specific_type(field_type, field.field_type):
-                field.field_type = field_type
+                field = field.cast(field_type)

     if field.field_type == 'string':
         max_column_length = series.astype('str').map(len).max()
@@ -70,7 +70,8 @@
         if field.statistics is None:
             field.statistics = statistics
         elif not isinstance(field.statistics, RecordsSchemaFieldStringStatistics):
-            raise SyntaxError("Did not expect to see existing statistics "
-                              f"for string type: {field.statistics}")
+            raise ValueError("Did not expect to see existing statistics "
+                             f"for string type: {field.statistics}")
         else:
             field.statistics.merge(statistics)
+    return field
diff --git a/records_mover/records/schema/field/sqlalchemy.py b/records_mover/records/schema/field/sqlalchemy.py
index c2d4e7e36..857648131 100644
--- a/records_mover/records/schema/field/sqlalchemy.py
+++ b/records_mover/records/schema/field/sqlalchemy.py
@@ -13,7 +13,7 @@
     from ....db import DBDriver  # noqa
     from ..field import RecordsSchemaField  # noqa
     from ..schema import RecordsSchema  # noqa
-    from .types import FieldType  # noqa
+    from .field_types import FieldType  # noqa

 logger = logging.getLogger(__name__)

@@ -127,6 +127,10 @@ def field_from_sqlalchemy_column(column: Column,
 def field_to_sqlalchemy_type(field: 'RecordsSchemaField',
                              driver: 'DBDriver') -> sqlalchemy.types.TypeEngine:
     if field.field_type == 'integer':
+        if field.constraints and\
+           not isinstance(field.constraints, RecordsSchemaFieldIntegerConstraints):
+            raise ValueError(f"Incorrect constraint type in {field.name}: {field.constraints}")
+
         int_constraints =\
             cast(Optional[RecordsSchemaFieldIntegerConstraints], field.constraints)
         min_: Optional[int] = None
@@ -162,11 +166,11 @@
     elif field.field_type == 'string':
         if field.constraints and\
            not isinstance(field.constraints, RecordsSchemaFieldStringConstraints):
-            raise SyntaxError(f"Incorrect constraint type: {field.constraints}")
+            raise ValueError(f"Incorrect constraint type in {field.name}: {field.constraints}")
         if field.statistics and\
            not isinstance(field.statistics, RecordsSchemaFieldStringStatistics):
-            raise SyntaxError(f"Incorrect statistics type: {field.statistics}")
+            raise ValueError(f"Incorrect statistics type in {field.name}: {field.statistics}")

         string_constraints =\
             cast(Optional[RecordsSchemaFieldStringConstraints], field.constraints)
diff --git a/records_mover/records/schema/field/statistics.py b/records_mover/records/schema/field/statistics.py
index 262f83078..f57f9e2d3 100644
--- a/records_mover/records/schema/field/statistics.py
+++ b/records_mover/records/schema/field/statistics.py
@@ -3,7 +3,7 @@

 if TYPE_CHECKING:
     from mypy_extensions import TypedDict
-    from .types import FieldType  # noqa
+    from .field_types import FieldType  # noqa

     class FieldStatisticsDict(TypedDict):
         rows_sampled: int
@@ -48,6 +48,10 @@ def from_data(data: Optional[Union['FieldStatisticsDict', 'StringFieldStatistics
         return RecordsSchemaFieldStatistics(rows_sampled=rows_sampled,
                                             total_rows=total_rows)

+    def cast(self, field_type: 'FieldType') -> Optional['RecordsSchemaFieldStatistics']:
+        # only string provides statistics at this point
+        return None
+
     def __str__(self) -> str:
         return f"{type(self)}({self.to_data()})"

@@ -74,3 +78,9 @@ def to_data(self) -> 'StringFieldStatisticsDict':

     def merge(self, other: 'RecordsSchemaFieldStringStatistics') -> None:
         raise NotImplementedError
+
+    def cast(self, field_type: 'FieldType') -> Optional['RecordsSchemaFieldStatistics']:
+        if field_type == 'string':
+            return self
+        else:
+            return super().cast(field_type)
diff --git a/records_mover/records/schema/field/types.py b/records_mover/records/schema/field/types.py
deleted file mode 100644
index 24c51adbf..000000000
--- a/records_mover/records/schema/field/types.py
+++ /dev/null
@@ -1,25 +0,0 @@
-from typing import TYPE_CHECKING
-# Be sure to add new things below in FieldType, too
-RECORDS_FIELD_TYPES = {
-    'integer',
-    'decimal',
-    'string',
-    'boolean',
-    'date',
-    'time',
-    'timetz',
-    'datetime',
-    'datetimetz'
-}
-
-if TYPE_CHECKING:
-    from typing_extensions import Literal  # noqa
-    FieldType = Literal['integer',
-                        'decimal',
-                        'string',
-                        'boolean',
-                        'date',
-                        'time',
-                        'timetz',
-                        'datetime',
-                        'datetimetz']
diff --git a/records_mover/records/schema/schema/__init__.py b/records_mover/records/schema/schema/__init__.py
index a5d8f75cf..990477055 100644
--- a/records_mover/records/schema/schema/__init__.py
+++ b/records_mover/records/schema/schema/__init__.py
@@ -162,14 +162,14 @@ def from_fileobjs(fileobjs: List[IO[bytes]],
             schema = RecordsSchema.from_dataframe(df,
                                                   processing_instructions,
                                                   include_index=False)

-            schema.refine_from_dataframe(df,
-                                         processing_instructions=processing_instructions)
+            schema = schema.refine_from_dataframe(df,
+                                                  processing_instructions=processing_instructions)

         return schema

     def refine_from_dataframe(self,
                               df: 'DataFrame',
                               processing_instructions:
-                              ProcessingInstructions = ProcessingInstructions()) -> None:
+                              ProcessingInstructions = ProcessingInstructions()) -> 'RecordsSchema':
         """
         Adjust records schema based on facts found from a dataframe.
""" diff --git a/records_mover/records/schema/schema/pandas.py b/records_mover/records/schema/schema/pandas.py index e2d19fdeb..1ff28f9b3 100644 --- a/records_mover/records/schema/schema/pandas.py +++ b/records_mover/records/schema/schema/pandas.py @@ -35,7 +35,10 @@ def schema_from_dataframe(df: DataFrame, def refine_schema_from_dataframe(records_schema: 'RecordsSchema', df: DataFrame, processing_instructions: - ProcessingInstructions = ProcessingInstructions()) -> None: + ProcessingInstructions = ProcessingInstructions()) ->\ + 'RecordsSchema': + from records_mover.records.schema import RecordsSchema + max_sample_size = processing_instructions.max_inference_rows total_rows = len(df.index) if max_sample_size is not None and max_sample_size < total_rows: @@ -44,8 +47,11 @@ def refine_schema_from_dataframe(records_schema: 'RecordsSchema', sampled_df = df rows_sampled = len(sampled_df.index) - for field in records_schema.fields: - series = sampled_df[field.name] - field.refine_from_series(series, + fields = [ + field.refine_from_series(sampled_df[field.name], total_rows=total_rows, rows_sampled=rows_sampled) + for field in records_schema.fields + ] + return RecordsSchema(fields=fields, + known_representations=records_schema.known_representations) diff --git a/records_mover/records/sources/dataframes.py b/records_mover/records/sources/dataframes.py index 6f772f26f..42ae0abfd 100644 --- a/records_mover/records/sources/dataframes.py +++ b/records_mover/records/sources/dataframes.py @@ -113,7 +113,7 @@ def schema_from_df(self, df: 'DataFrame', # Otherwise, gather information to create an efficient # schema on the target of the move. # - records_schema.refine_from_dataframe(df, processing_instructions) + records_schema = records_schema.refine_from_dataframe(df, processing_instructions) return records_schema diff --git a/setup.cfg b/setup.cfg index 13ca93f1a..17c0ef619 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,6 +32,9 @@ max-complexity = 15 mypy_path = types/stubs warn_unused_ignores = True +[mypy-alembic.*] +ignore_missing_imports = True + # https://github.com/pandas-dev/pandas/issues/26766 # https://github.com/pandas-dev/pandas/issues/26792 # https://github.com/pandas-dev/pandas/issues/28142 diff --git a/tests/component/records/schema/field/test_pandas.py b/tests/component/records/schema/field/test_pandas.py new file mode 100644 index 000000000..e6dfc3823 --- /dev/null +++ b/tests/component/records/schema/field/test_pandas.py @@ -0,0 +1,109 @@ +import unittest +import datetime +import pytz +from records_mover.records.schema.field.constraints import ( + RecordsSchemaFieldConstraints, + RecordsSchemaFieldStringConstraints, + RecordsSchemaFieldIntegerConstraints, + RecordsSchemaFieldDecimalConstraints, +) +from records_mover.records.schema.field.statistics import ( + RecordsSchemaFieldStringStatistics, +) +from records_mover.records.schema.field.representation import RecordsSchemaPandasFieldRepresentation +from records_mover.records.schema.field import RecordsSchemaField +from records_mover.records.schema.field.pandas import refine_field_from_series +from records_mover.records.schema.field.field_types import RECORDS_FIELD_TYPES +import pandas as pd + + +class TestPandas(unittest.TestCase): + def test_refine_field_from_series_more_specific(self) -> None: + # This test is designed to break when a new field type is + # introduced, so you can add new expectations and make sure + # the code handles the new type! 
+
+        fields = {
+            'integer': {
+                'series': pd.Series([30, 35, 40]),
+                'constraints_type': RecordsSchemaFieldIntegerConstraints,
+                'statistics_type': type(None),
+            },
+            'decimal': {
+                'series': pd.Series([30.0, 35.1, 40.2]),
+                'constraints_type': RecordsSchemaFieldDecimalConstraints,
+                'statistics_type': type(None),
+            },
+            'string': {
+                'series': pd.Series(['a', 'b', 'c']),
+                'constraints_type': RecordsSchemaFieldStringConstraints,
+                'statistics_type': RecordsSchemaFieldStringStatistics,
+            },
+            'boolean': {
+                'series': pd.Series([True, True, False]),
+                'constraints_type': RecordsSchemaFieldConstraints,
+                'statistics_type': type(None),
+            },
+            'date': {
+                'series': pd.Series([datetime.date(2020, 1, 1)]),
+                'constraints_type': RecordsSchemaFieldConstraints,
+                'statistics_type': type(None),
+            },
+            'time': {
+                'series': pd.Series([datetime.time(hour=12, minute=0, second=0)]),
+                'constraints_type': RecordsSchemaFieldConstraints,
+                'statistics_type': type(None),
+            },
+            'timetz': {
+                'series': pd.Series([datetime.time(hour=12, minute=0, second=0,
+                                                   tzinfo=pytz.timezone('US/Eastern'))]),
+                # refine_field_from_series() is not smart enough to
+                # distinguish whether the time objects inside it all
+                # have timezones or not.
+                'expected_field_type': 'time',
+                'constraints_type': RecordsSchemaFieldConstraints,
+                'statistics_type': type(None),
+            },
+            'datetime': {
+                'series': pd.Series([datetime.datetime(2020, 1, 1, hour=12)]),
+                'constraints_type': RecordsSchemaFieldConstraints,
+                'statistics_type': type(None),
+            },
+            'datetimetz': {
+                'series': pd.Series([datetime.datetime(2020, 1, 1, hour=12,
+                                                       tzinfo=pytz.timezone('US/Eastern'))]),
+                # refine_field_from_series() is not smart enough to
+                # distinguish whether the datetime objects inside it all
+                # have timezones or not.
+                'expected_field_type': 'datetime',
+                'constraints_type': RecordsSchemaFieldConstraints,
+                'statistics_type': type(None),
+            }
+        }
+        for field_type in RECORDS_FIELD_TYPES:
+            constraints = RecordsSchemaFieldStringConstraints(required=True,
+                                                              unique=False,
+                                                              max_length_bytes=255,
+                                                              max_length_chars=255)
+            pandas_representation = RecordsSchemaPandasFieldRepresentation(pd_df_dtype={},
+                                                                           pd_df_ftype=None,
+                                                                           pd_df_coltype='series')
+            field = RecordsSchemaField(name='testfield',
+                                       field_type='string',
+                                       constraints=constraints,
+                                       statistics=None,
+                                       representations={'pandas': pandas_representation})
+            series = fields[field_type]['series']
+            returned_field = refine_field_from_series(field,
+                                                      series,
+                                                      total_rows=10,
+                                                      rows_sampled=10)
+            if 'expected_field_type' in fields[field_type]:
+                self.assertEqual(returned_field.field_type,
+                                 fields[field_type]['expected_field_type'])
+            else:
+                self.assertEqual(returned_field.field_type, field_type)
+            self.assertEqual(type(returned_field.constraints),
+                             fields[field_type]['constraints_type'])
+            self.assertEqual(type(returned_field.statistics),
+                             fields[field_type]['statistics_type'])
diff --git a/tests/component/records/test_dataframe_schema_sql_creation.py b/tests/component/records/test_dataframe_schema_sql_creation.py
new file mode 100644
index 000000000..0ba5a2bc0
--- /dev/null
+++ b/tests/component/records/test_dataframe_schema_sql_creation.py
@@ -0,0 +1,40 @@
+from records_mover.records.sources.dataframes import DataframesRecordsSource
+from records_mover.db.redshift.redshift_db_driver import RedshiftDBDriver
+from records_mover.records.processing_instructions import ProcessingInstructions
+from sqlalchemy_redshift.dialect import RedshiftDialect
+import unittest
+from unittest.mock import Mock
+from pandas import DataFrame
+
+
+class TestDataframeSchemaSqlCreation(unittest.TestCase):
+    def test_dataframe_to_int64_and_back_to_object_produces_int_columns(self) -> None:
+        # This reproduces a situation found when a user worked around
+        # a separate historical Records Mover limitation by doing an
+        # unusual cast on their dataframe...and then hit a separate
+        # limitation:
+        #
+        # https://github.com/bluelabsio/records-mover/pull/103
+        data = {'Population': [11190846, 1303171035, 207847528]}
+        df = DataFrame(data, columns=['Population'])
+
+        df['Population'] = df['Population'].astype("Int64")
+        df['Population'] = df['Population'].astype("object")
+
+        source = DataframesRecordsSource(dfs=[df])
+        processing_instructions = ProcessingInstructions()
+        schema = source.initial_records_schema(processing_instructions)
+        dialect = RedshiftDialect()
+        mock_engine = Mock(name='engine')
+        mock_engine.dialect = dialect
+        driver = RedshiftDBDriver(db=mock_engine)
+        schema_sql = schema.to_schema_sql(driver=driver,
+                                          schema_name='my_schema_name',
+                                          table_name='my_table_name')
+        expected_schema_sql = """
+CREATE TABLE my_schema_name.my_table_name (
+\t"Population" INTEGER
+)
+
+"""
+        self.assertEqual(schema_sql, expected_schema_sql)
diff --git a/tests/integration/records/single_db/test_records_save_df.py b/tests/integration/records/single_db/test_records_save_df.py
index 1b935ab27..1242d08b8 100644
--- a/tests/integration/records/single_db/test_records_save_df.py
+++ b/tests/integration/records/single_db/test_records_save_df.py
@@ -42,7 +42,7 @@ def save_and_verify(self, records_format, processing_instructions=None) -> None:
         records_schema = RecordsSchema.from_dataframe(df,
                                                       processing_instructions,
                                                       include_index=False)
-        records_schema.refine_from_dataframe(df,
-                                             processing_instructions)
+        records_schema = records_schema.refine_from_dataframe(df, processing_instructions)

         with tempfile.TemporaryDirectory(prefix='test_records_save_df') as tempdir:
             output_url = pathlib.Path(tempdir).resolve().as_uri() + '/'
diff --git a/tests/unit/records/schema/field/sqlalchemy/test_field_to_sqlalchemy_type.py b/tests/unit/records/schema/field/sqlalchemy/test_field_to_sqlalchemy_type.py
index 852b1d929..5cb31e2f1 100644
--- a/tests/unit/records/schema/field/sqlalchemy/test_field_to_sqlalchemy_type.py
+++ b/tests/unit/records/schema/field/sqlalchemy/test_field_to_sqlalchemy_type.py
@@ -2,19 +2,26 @@
 from mock import Mock, patch
 import sqlalchemy
 from records_mover.records.schema.field.sqlalchemy import field_to_sqlalchemy_type
-from records_mover.records.schema.field.constraints import RecordsSchemaFieldStringConstraints
+from records_mover.records.schema.field.constraints import (RecordsSchemaFieldStringConstraints,
+                                                            RecordsSchemaFieldIntegerConstraints)
 from records_mover.records.schema.field.statistics import RecordsSchemaFieldStringStatistics


 class TestSqlAlchemyFieldToSqlalchemyType(unittest.TestCase):
     def test_integer(self):
         mock_field = Mock(name='field')
+        mock_driver = Mock(name='driver')
         mock_field.field_type = 'integer'
-        mock_int_constraints = mock_field.constraints
-        mock_min_ = mock_int_constraints.min_
-        mock_max_ = mock_int_constraints.max_
+        mock_int_constraints = Mock(name='int_constraints',
+                                    spec=RecordsSchemaFieldIntegerConstraints)
+
+        mock_field.constraints = mock_int_constraints
+        mock_min_ = Mock(name='min_')
+        mock_int_constraints.min_ = mock_min_
+        mock_max_ = Mock(name='max_')
+        mock_int_constraints.max_ = mock_max_

         out = field_to_sqlalchemy_type(field=mock_field, driver=mock_driver)
diff --git a/tests/unit/records/schema/field/test_pandas.py b/tests/unit/records/schema/field/test_pandas.py
index b403bf57c..a92eb1756 100644
--- a/tests/unit/records/schema/field/test_pandas.py
+++ b/tests/unit/records/schema/field/test_pandas.py
@@ -1,9 +1,7 @@
 import unittest
 from mock import Mock, patch
-from pandas import Series
 from records_mover.records.schema.field.pandas import (field_from_index,
-                                                       field_from_series,
-                                                       refine_field_from_series)
+                                                       field_from_series)


 class TestPandas(unittest.TestCase):
@@ -56,61 +54,3 @@
                                    representations=mock_representations,
                                    statistics=None)
         self.assertEqual(out, mock_RecordsSchemaField.return_value)
-
-    @patch('records_mover.records.schema.field.RecordsSchemaField')
-    @patch('records_mover.records.schema.field.pandas.RecordsSchemaFieldStringStatistics')
-    def test_refine_field_from_series_all_string(self,
-                                                 mock_RecordsSchemaFieldStringStatistics,
-                                                 mock_RecordsSchemaField):
-        mock_field = Mock(name='field')
-        mock_field.statistics = None
-        mock_field.field_type = 'string'
-        series = Series(["mumble", "foo", "b"])
-        mock_total_rows = Mock(name='total_rows')
-        mock_rows_sampled = Mock(name='rows_sampled')
-
-        mock_field.python_type_to_field_type.return_value = 'string'
-
-        mock_RecordsSchemaField.is_more_specific_type.return_value = True
-
-        mock_statistics = mock_RecordsSchemaFieldStringStatistics.return_value
-
-        refine_field_from_series(mock_field,
-                                 series=series,
-                                 total_rows=mock_total_rows,
-                                 rows_sampled=mock_rows_sampled)
-        mock_RecordsSchemaFieldStringStatistics.\
-            assert_called_with(max_length_bytes=None,
-                               max_length_chars=6,
-                               rows_sampled=mock_rows_sampled,
-                               total_rows=mock_total_rows)
-        self.assertEqual(mock_field.statistics, mock_statistics)
-        self.assertEqual(mock_field.field_type, 'string')
-
-    @patch('records_mover.records.schema.field.RecordsSchemaField')
-    @patch('records_mover.records.schema.field.pandas.RecordsSchemaFieldStringStatistics')
-    def test_refine_field_from_series_diverse(self,
-                                              mock_RecordsSchemaFieldStringStatistics,
-                                              mock_RecordsSchemaField):
-        mock_field = Mock(name='field')
-        mock_field.statistics = None
-        mock_field.field_type = 'string'
-        series = Series(["mumble", "foo", 1])
-        mock_total_rows = Mock(name='total_rows')
-        mock_rows_sampled = Mock(name='rows_sampled')
-
-        mock_RecordsSchemaField.is_more_specific_type.return_value = True
-
-        mock_statistics = mock_RecordsSchemaFieldStringStatistics.return_value
-
-        refine_field_from_series(mock_field,
-                                 series=series,
-                                 total_rows=mock_total_rows,
-                                 rows_sampled=mock_rows_sampled)
-        mock_RecordsSchemaFieldStringStatistics.\
-            assert_called_with(max_length_bytes=None,
-                               max_length_chars=6,
-                               rows_sampled=mock_rows_sampled,
-                               total_rows=mock_total_rows,)
-        self.assertEqual(mock_field.statistics, mock_statistics)
-        self.assertEqual(mock_field.field_type, 'string')
diff --git a/tests/unit/records/schema/test_records_schema.py b/tests/unit/records/schema/test_records_schema.py
index 1603cdb19..104254811 100644
--- a/tests/unit/records/schema/test_records_schema.py
+++ b/tests/unit/records/schema/test_records_schema.py
@@ -99,7 +99,9 @@ def test_from_fileobjs(self,
             {'Country': 'Brazil', 'Capital': 'Brasília', 'Population': 207847528},
         ]
         self.assertEqual(actual_cleaned_up_df_data, expected_cleaned_up_df_data)
-        self.assertEqual(out, mock_RecordsSchema.from_dataframe.return_value)
+        self.assertEqual(out,
+                         mock_RecordsSchema.from_dataframe.return_value.
+                         refine_from_dataframe.return_value)

     @patch('records_mover.records.schema.schema.RecordsSchema')
     @patch('records_mover.records.delimited.stream_csv')
@@ -138,7 +140,9 @@ def test_from_fileobjs_no_max_inference_rows(self,
             {'Country': 'Brazil', 'Capital': 'Brasília', 'Population': 207847528},
         ]
         self.assertEqual(actual_cleaned_up_df_data, expected_cleaned_up_df_data)
-        self.assertEqual(out, mock_RecordsSchema.from_dataframe.return_value)
+        self.assertEqual(out,
+                         mock_RecordsSchema.from_dataframe.return_value.
+                         refine_from_dataframe.return_value)

     @patch('records_mover.records.schema.schema.pandas.refine_schema_from_dataframe')
     def test_refine_from_dataframe(self,
diff --git a/tests/unit/records/sources/test_dataframes.py b/tests/unit/records/sources/test_dataframes.py
index 8c0e2cdd4..a4592cadd 100644
--- a/tests/unit/records/sources/test_dataframes.py
+++ b/tests/unit/records/sources/test_dataframes.py
@@ -49,7 +49,8 @@
         def generate_filename(prefix):
             return f"{prefix}.csv"

         mock_target_records_format.generate_filename = generate_filename
-        mock_target_records_schema = mock_RecordsSchema.from_dataframe.return_value
+        mock_target_records_schema = mock_RecordsSchema.from_dataframe.return_value.\
+            refine_from_dataframe.return_value
         mock_purge_unnamed_unused_columns.side_effect = lambda a: a
         with dataframe_records_source.\
                 to_fileobjs_source(records_format_if_possible=mock_target_records_format,
@@ -124,7 +125,8 @@
         mock_target_records_format.generate_filename = generate_filename

-        mock_target_records_schema = mock_RecordsSchema.from_dataframe.return_value
+        mock_target_records_schema = mock_RecordsSchema.from_dataframe.return_value.\
+            refine_from_dataframe.return_value
         mock_output_file = mock_NamedTemporaryFile.return_value.__enter__.return_value
         mock_output_filename = mock_output_file.name
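
Reviewer note: this diff leans on the mypy exhaustiveness-check idiom, with _assert_never() now accepting an optional message. Below is a minimal, self-contained sketch of how the pattern behaves; the describe() function is hypothetical and only stands in for a caller like RecordsSchemaFieldConstraints.cast().

    from typing import NoReturn, Optional
    from typing_extensions import Literal

    FieldType = Literal['integer', 'string']  # trimmed for illustration


    def _assert_never(x: NoReturn, errmsg: Optional[str] = None) -> NoReturn:
        if errmsg is None:
            errmsg = "Unhandled type: {}".format(type(x).__name__)
        assert False, errmsg


    def describe(field_type: FieldType) -> str:  # hypothetical caller
        if field_type == 'integer':
            return 'whole numbers'
        elif field_type == 'string':
            return 'text'
        else:
            # mypy narrows field_type to NoReturn only once every Literal
            # member has a branch above; adding a member to FieldType makes
            # this call a type error until the new case is handled.
            _assert_never(field_type, f'Teach me how to describe {field_type}')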
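The new field_types.py also replaces a hand-maintained set with a list derived from the Literal type itself, so the runtime constant and the static type can no longer drift apart. A standalone sketch of the same trick under illustrative names (Color/COLORS are not part of the codebase):

    from typing import List
    from typing_extensions import Literal
    from typing_inspect import get_args

    Color = Literal['red', 'green']
    # get_args() unpacks the Literal's members at runtime; mypy cannot
    # type this precisely, which is why the real module carries a
    # `# type: ignore` on the equivalent line.
    COLORS: List[str] = list(get_args(Color))
    assert COLORS == ['red', 'green']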
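Finally, a behavioral note for callers: refine_from_dataframe() and refine_from_series() now return the refined schema or field instead of None, since cast() builds fresh objects rather than mutating in place. A sketch of the required call-site change (df and processing_instructions assumed to be in scope):

    schema = RecordsSchema.from_dataframe(df, processing_instructions,
                                          include_index=False)
    # Before this change (mutated in place, returned None):
    #   schema.refine_from_dataframe(df, processing_instructions)
    # After: dropping the return value would silently keep the coarse schema.
    schema = schema.refine_from_dataframe(df, processing_instructions)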