From 3fc894cfc8cbf24a549b4d87bbb5b961ef372141 Mon Sep 17 00:00:00 2001 From: Chris Wegrzyn Date: Wed, 20 Jan 2021 17:02:49 -0500 Subject: [PATCH] Prefer pandas nullable integers for int fields (#159) --- .gitignore | 3 +- metrics/coverage_high_water_mark | 2 +- metrics/mypy_high_water_mark | 2 +- records_mover/db/mysql/mysql_db_driver.py | 53 +++---- records_mover/records/pandas/prep_for_csv.py | 4 + .../records/schema/field/__init__.py | 56 ++++---- .../schema/field/constraints/constraints.py | 3 + records_mover/records/schema/field/pandas.py | 53 ++++++- records_mover/utils/limits.py | 33 +++++ setup.cfg | 3 + .../records/schema/field/test_dtype.py | 132 ++++++++++++++++++ .../multi_db/test_records_table2table.py | 5 +- .../records/single_db/base_records_test.py | 5 +- tests/integration/resources/_schema.json | 2 +- tests/unit/records/schema/field/test_field.py | 112 --------------- tests/unit/records/sources/test_fileobjs.py | 17 +++ 16 files changed, 304 insertions(+), 181 deletions(-) create mode 100644 tests/component/records/schema/field/test_dtype.py diff --git a/.gitignore b/.gitignore index a8c7a9e9b..b220f30d0 100644 --- a/.gitignore +++ b/.gitignore @@ -37,8 +37,7 @@ pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ -.coverage -.coverage.* +.coverage* .cache nosetests.xml coverage.xml diff --git a/metrics/coverage_high_water_mark b/metrics/coverage_high_water_mark index 440d931ef..5f2eabeef 100644 --- a/metrics/coverage_high_water_mark +++ b/metrics/coverage_high_water_mark @@ -1 +1 @@ -93.6100 +93.6700 diff --git a/metrics/mypy_high_water_mark b/metrics/mypy_high_water_mark index b21df6ea0..056053212 100644 --- a/metrics/mypy_high_water_mark +++ b/metrics/mypy_high_water_mark @@ -1 +1 @@ -92.5000 +92.1900 \ No newline at end of file diff --git a/records_mover/db/mysql/mysql_db_driver.py b/records_mover/db/mysql/mysql_db_driver.py index 6d7615107..f50822166 100644 --- a/records_mover/db/mysql/mysql_db_driver.py +++ b/records_mover/db/mysql/mysql_db_driver.py @@ -1,16 +1,7 @@ import sqlalchemy import sqlalchemy.dialects.mysql import logging -from ...utils.limits import (INT8_MIN, INT8_MAX, - UINT8_MIN, UINT8_MAX, - INT16_MIN, INT16_MAX, - UINT16_MIN, UINT16_MAX, - INT24_MIN, INT24_MAX, - UINT24_MIN, UINT24_MAX, - INT32_MIN, INT32_MAX, - UINT32_MIN, UINT32_MAX, - INT64_MIN, INT64_MAX, - UINT64_MIN, UINT64_MAX, +from ...utils.limits import (IntegerType, FLOAT32_SIGNIFICAND_BITS, FLOAT64_SIGNIFICAND_BITS, num_digits) @@ -48,29 +39,29 @@ def integer_limits(self, Optional[Tuple[int, int]]: if isinstance(type_, sqlalchemy.dialects.mysql.TINYINT): if type_.unsigned: - return (UINT8_MIN, UINT8_MAX) + return IntegerType.UINT8.range else: - return (INT8_MIN, INT8_MAX) + return IntegerType.INT8.range elif isinstance(type_, sqlalchemy.dialects.mysql.SMALLINT): if type_.unsigned: - return (UINT16_MIN, UINT16_MAX) + return IntegerType.UINT16.range else: - return (INT16_MIN, INT16_MAX) + return IntegerType.INT16.range elif isinstance(type_, sqlalchemy.dialects.mysql.MEDIUMINT): if type_.unsigned: - return (UINT24_MIN, UINT24_MAX) + return IntegerType.UINT24.range else: - return (INT24_MIN, INT24_MAX) + return IntegerType.INT24.range elif isinstance(type_, sqlalchemy.dialects.mysql.INTEGER): if type_.unsigned: - return (UINT32_MIN, UINT32_MAX) + return IntegerType.UINT32.range else: - return (INT32_MIN, INT32_MAX) + return IntegerType.INT32.range elif isinstance(type_, sqlalchemy.dialects.mysql.BIGINT): if type_.unsigned: - return (UINT64_MIN, UINT64_MAX) + return IntegerType.UINT64.range else: - return (INT64_MIN, INT64_MAX) + return IntegerType.INT64.range return super().integer_limits(type_) def fp_constraints(self, @@ -88,26 +79,26 @@ def type_for_integer(self, """Find correct integral column type to fit the given min and max integer values""" if min_value is not None and max_value is not None: - pass - if min_value >= INT8_MIN and max_value <= INT8_MAX: + int_type = IntegerType.smallest_cover_for(min_value, max_value) + if int_type == IntegerType.INT8: return sqlalchemy.dialects.mysql.TINYINT() - elif min_value >= UINT8_MIN and max_value <= UINT8_MAX: + elif int_type == IntegerType.UINT8: return sqlalchemy.dialects.mysql.TINYINT(unsigned=True) - elif min_value >= INT16_MIN and max_value <= INT16_MAX: + elif int_type == IntegerType.INT16: return sqlalchemy.sql.sqltypes.SMALLINT() - elif min_value >= UINT16_MIN and max_value <= UINT16_MAX: + elif int_type == IntegerType.UINT16: return sqlalchemy.dialects.mysql.SMALLINT(unsigned=True) - elif min_value >= INT24_MIN and max_value <= INT24_MAX: + elif int_type == IntegerType.INT24: return sqlalchemy.dialects.mysql.MEDIUMINT() - elif min_value >= UINT24_MIN and max_value <= UINT24_MAX: + elif int_type == IntegerType.UINT24: return sqlalchemy.dialects.mysql.MEDIUMINT(unsigned=True) - elif min_value >= INT32_MIN and max_value <= INT32_MAX: + elif int_type == IntegerType.INT32: return sqlalchemy.sql.sqltypes.INTEGER() - elif min_value >= UINT32_MIN and max_value <= UINT32_MAX: + elif int_type == IntegerType.UINT32: return sqlalchemy.dialects.mysql.INTEGER(unsigned=True) - elif min_value >= INT64_MIN and max_value <= INT64_MAX: + elif int_type == IntegerType.INT64: return sqlalchemy.sql.sqltypes.BIGINT() - elif min_value >= UINT64_MIN and max_value <= UINT64_MAX: + elif int_type == IntegerType.UINT64: return sqlalchemy.dialects.mysql.BIGINT(unsigned=True) else: num_digits_min = num_digits(min_value) diff --git a/records_mover/records/pandas/prep_for_csv.py b/records_mover/records/pandas/prep_for_csv.py index eb24b9bed..4839f164f 100644 --- a/records_mover/records/pandas/prep_for_csv.py +++ b/records_mover/records/pandas/prep_for_csv.py @@ -26,6 +26,8 @@ def _convert_series_or_index(series_or_index: T, isinstance(series_or_index[0], datetime.date)): logger.info(f"Converting {series_or_index.name} from np.datetime64 to " "string in CSV's format") + logger.debug("Dtype is %s, first element type %s", series_or_index.dtype, + type(series_or_index[0])) hint_date_format = records_format.hints['dateformat'] assert isinstance(hint_date_format, str) pandas_date_format = python_date_format_from_hints.get(hint_date_format) @@ -49,6 +51,8 @@ def _convert_series_or_index(series_or_index: T, else: logger.info(f"Converting {series_or_index.name} from np.datetime64 to string " "in CSV's format") + logger.debug("Dtype is %s, first element type %s", series_or_index.dtype, + type(series_or_index[0])) hint_time_format = records_format.hints['timeonlyformat'] assert isinstance(hint_time_format, str) pandas_time_format = python_time_format_from_hints.get(hint_time_format) diff --git a/records_mover/records/schema/field/__init__.py b/records_mover/records/schema/field/__init__.py index 7feac8312..f4d4e2906 100644 --- a/records_mover/records/schema/field/__init__.py +++ b/records_mover/records/schema/field/__init__.py @@ -1,16 +1,8 @@ import datetime from ...processing_instructions import ProcessingInstructions import logging -from typing import Optional, Dict, Any, Type, cast, Union, TYPE_CHECKING -from ....utils.limits import (INT8_MIN, INT8_MAX, - UINT8_MIN, UINT8_MAX, - INT16_MIN, INT16_MAX, - UINT16_MIN, UINT16_MAX, - INT32_MIN, INT32_MAX, - UINT32_MIN, UINT32_MAX, - INT64_MIN, INT64_MAX, - UINT64_MIN, UINT64_MAX, - FLOAT16_SIGNIFICAND_BITS, +from typing import Optional, Dict, Any, Type, cast, TYPE_CHECKING +from ....utils.limits import (FLOAT16_SIGNIFICAND_BITS, FLOAT32_SIGNIFICAND_BITS, FLOAT64_SIGNIFICAND_BITS, FLOAT80_SIGNIFICAND_BITS) @@ -26,6 +18,7 @@ from sqlalchemy.types import TypeEngine from records_mover.db import DBDriver # noqa from .field_types import FieldType + from .pandas import Dtype from mypy_extensions import TypedDict @@ -192,40 +185,41 @@ def components_to_time_str(df: pd.DataFrame) -> datetime.time: return datetime.time(hour=df['hours'], minute=df['minutes'], second=df['seconds']) + logger.debug("Applying pd.Timedelta logic on series for %s", self.name) out = series.dt.components.apply(axis=1, func=components_to_time_str) return out - return series.astype(self.to_numpy_dtype()) + target_type = self.to_pandas_dtype() + logger.debug("Casting field %s from type %r to type %s", self.name, series.dtype, + target_type) + return series.astype(target_type) - def to_numpy_dtype(self) -> Union[Type[Any], str]: + def to_pandas_dtype(self) -> 'Dtype': import numpy as np + import pandas as pd + from .pandas import supports_nullable_ints, integer_type_for_range + + has_extension_types = supports_nullable_ints() if self.field_type == 'integer': int_constraints =\ cast(Optional[RecordsSchemaFieldIntegerConstraints], self.constraints) min_: Optional[int] = None max_: Optional[int] = None + required = False if int_constraints: min_ = int_constraints.min_ max_ = int_constraints.max_ + required = int_constraints.required + + if not required and not has_extension_types: + logger.warning(f"Dataframe field {self.name} is nullable, but using pandas " + f"{pd.__version__} which does not support nullable integer type") if min_ is not None and max_ is not None: - if min_ >= INT8_MIN and max_ <= INT8_MAX: - return np.int8 - elif min_ >= UINT8_MIN and max_ <= UINT8_MAX: - return np.uint8 - elif min_ >= INT16_MIN and max_ <= INT16_MAX: - return np.int16 - elif min_ >= UINT16_MIN and max_ <= UINT16_MAX: - return np.uint16 - elif min_ >= INT32_MIN and max_ <= INT32_MAX: - return np.int32 - elif min_ >= UINT32_MIN and max_ <= UINT32_MAX: - return np.uint32 - elif min_ >= INT64_MIN and max_ <= INT64_MAX: - return np.int64 - elif min_ >= UINT64_MIN and max_ <= UINT64_MAX: - return np.uint64 + dtype = integer_type_for_range(min_, max_, has_extension_types) + if dtype: + return dtype else: logger.warning("Asked for a type larger than int64 in dataframe " f"field '{self.name}' - providing float128, but " @@ -235,8 +229,10 @@ def to_numpy_dtype(self) -> Union[Type[Any], str]: else: logger.warning(f"No integer constraints provided for field '{self.name}'; " "using int64") - return np.int64 - # return driver.type_for_integer(min_=min_, max_=max_) + if has_extension_types: + return pd.Int64Dtype() + else: + return np.int64 elif self.field_type == 'decimal': decimal_constraints =\ cast(Optional[RecordsSchemaFieldDecimalConstraints], self.constraints) diff --git a/records_mover/records/schema/field/constraints/constraints.py b/records_mover/records/schema/field/constraints/constraints.py index 3ded475dc..92c396434 100644 --- a/records_mover/records/schema/field/constraints/constraints.py +++ b/records_mover/records/schema/field/constraints/constraints.py @@ -204,3 +204,6 @@ def from_numpy_dtype(dtype: 'np.dtype', def __str__(self) -> str: return f"{type(self).__name__}({self.to_data()})" + + def __repr__(self) -> str: + return self.__str__() diff --git a/records_mover/records/schema/field/pandas.py b/records_mover/records/schema/field/pandas.py index 5778ddfb1..ecc5b1557 100644 --- a/records_mover/records/schema/field/pandas.py +++ b/records_mover/records/schema/field/pandas.py @@ -1,13 +1,64 @@ +import pandas as pd from pandas import Series, Index -from typing import Any, Type, TYPE_CHECKING +from typing import Any, Type, TYPE_CHECKING, Optional, Mapping, Union from .statistics import RecordsSchemaFieldStringStatistics from ...processing_instructions import ProcessingInstructions from .representation import RecordsSchemaFieldRepresentation +from ....utils.limits import IntegerType from .numpy import details_from_numpy_dtype import numpy as np if TYPE_CHECKING: from ..field import RecordsSchemaField # noqa from ..schema import RecordsSchema # noqa + from pandas.core.dtypes.dtypes import ExtensionDtype # noqa + +# Cribbed from non-public https://github.com/pandas-dev/pandas/blob/v1.2.1/pandas/_typing.py +Dtype = Union[ + "ExtensionDtype", str, np.dtype, Type[Union[str, float, int, complex, bool, object]] +] +DtypeObj = Union[np.dtype, "ExtensionDtype"] + + +def supports_nullable_ints() -> bool: + """Detects if this version of pandas supports nullable int extension types.""" + return 'Int64Dtype' in dir(pd) + + +def integer_type_mapping(use_extension_types: bool) -> Mapping[IntegerType, DtypeObj]: + if use_extension_types: + return { + IntegerType.INT8: pd.Int8Dtype(), + IntegerType.UINT8: pd.UInt8Dtype(), + IntegerType.INT16: pd.Int16Dtype(), + IntegerType.UINT16: pd.UInt16Dtype(), + IntegerType.INT24: pd.Int32Dtype(), + IntegerType.UINT24: pd.Int32Dtype(), + IntegerType.INT32: pd.Int32Dtype(), + IntegerType.UINT32: pd.UInt32Dtype(), + IntegerType.INT64: pd.Int64Dtype(), + IntegerType.UINT64: pd.UInt64Dtype(), + } + else: + return { + IntegerType.INT8: np.int8, + IntegerType.UINT8: np.uint8, + IntegerType.INT16: np.int16, + IntegerType.UINT16: np.uint16, + IntegerType.INT24: np.int32, + IntegerType.UINT24: np.uint32, + IntegerType.INT32: np.int32, + IntegerType.UINT32: np.uint32, + IntegerType.INT64: np.int64, + IntegerType.UINT64: np.uint64, + } + + +def integer_type_for_range(min_: int, max_: int, has_extension_types: bool) -> Optional[DtypeObj]: + int_type = IntegerType.smallest_cover_for(min_, max_) + if int_type: + return integer_type_mapping(has_extension_types).get(int_type) + else: + return None def field_from_index(index: Index, diff --git a/records_mover/utils/limits.py b/records_mover/utils/limits.py index f6081b6cc..4d4fa0b4b 100644 --- a/records_mover/utils/limits.py +++ b/records_mover/utils/limits.py @@ -1,4 +1,6 @@ +from enum import Enum import math +from typing import Optional INT8_MAX = 127 INT8_MIN = -128 @@ -26,6 +28,37 @@ FLOAT80_SIGNIFICAND_BITS = 64 +class IntegerType(Enum): + INT8 = (INT8_MIN, INT8_MAX) + UINT8 = (UINT8_MIN, UINT8_MAX) + INT16 = (INT16_MIN, INT16_MAX) + UINT16 = (UINT16_MIN, UINT16_MAX) + INT24 = (INT24_MIN, INT24_MAX) + UINT24 = (UINT24_MIN, UINT24_MAX) + INT32 = (INT32_MIN, INT32_MAX) + UINT32 = (UINT32_MIN, UINT32_MAX) + INT64 = (INT64_MIN, INT64_MAX) + UINT64 = (UINT64_MIN, UINT64_MAX) + + def __init__(self, min_: int, max_: int): + self.min_ = min_ + self.max_ = max_ + + def is_cover_for(self, low_value: int, high_value: int) -> bool: + return low_value >= self.min_ and high_value <= self.max_ + + @property + def range(self): + return (self.min_, self.max_) + + @classmethod + def smallest_cover_for(cls, low_value: int, high_value: int) -> Optional['IntegerType']: + for int_type in cls: + if int_type.is_cover_for(low_value, high_value): + return int_type + return None + + # https://stackoverflow.com/questions/2189800/length-of-an-integer-in-python def num_digits(n: int) -> int: if n > 0: diff --git a/setup.cfg b/setup.cfg index 17c0ef619..1b5f080f8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -76,3 +76,6 @@ ignore_missing_imports = True [mypy-pyarrow.*] ignore_missing_imports = True + +[mypy-nose.*] +ignore_missing_imports = True diff --git a/tests/component/records/schema/field/test_dtype.py b/tests/component/records/schema/field/test_dtype.py new file mode 100644 index 000000000..4498fe0e0 --- /dev/null +++ b/tests/component/records/schema/field/test_dtype.py @@ -0,0 +1,132 @@ +from nose.tools import assert_equal +from mock import patch +from records_mover.records.schema.field import RecordsSchemaField +from records_mover.records.schema.field.constraints import ( + RecordsSchemaFieldIntegerConstraints, + RecordsSchemaFieldDecimalConstraints, +) +import numpy as np +import pandas as pd + + +def with_nullable(nullable: bool, fn): + def wrapfn(*args, **kwargs): + with patch( + "records_mover.records.schema.field.pandas.supports_nullable_ints", + return_value=nullable, + ): + fn(*args, **kwargs) + + return wrapfn + + +def check_dtype(field_type, constraints, expectation): + field = RecordsSchemaField( + name="test", + field_type=field_type, + constraints=constraints, + statistics=None, + representations=None, + ) + out = field.cast_series_type(pd.Series(1, dtype=np.int8)) + assert_equal(out.dtype, expectation) + + +def test_to_pandas_dtype_integer_no_nullable(): + expectations = { + (-100, 100): np.int8, + (0, 240): np.uint8, + (-10000, 10000): np.int16, + (500, 40000): np.uint16, + (-200000000, 200000000): np.int32, + (25, 4000000000): np.uint32, + (-9000000000000000000, 2000000000): np.int64, + (25, 10000000000000000000): np.uint64, + (25, 1000000000000000000000000000): np.float128, + (None, None): np.int64, + } + for (min_, max_), expected_pandas_type in expectations.items(): + constraints = RecordsSchemaFieldIntegerConstraints( + required=True, unique=None, min_=min_, max_=max_ + ) + yield with_nullable( + False, check_dtype + ), "integer", constraints, expected_pandas_type + + +def test_to_pandas_dtype_integer_nullable(): + expectations = { + (-100, 100): pd.Int8Dtype(), + (0, 240): pd.UInt8Dtype(), + (-10000, 10000): pd.Int16Dtype(), + (500, 40000): pd.UInt16Dtype(), + (-200000000, 200000000): pd.Int32Dtype(), + (25, 4000000000): pd.UInt32Dtype(), + (-9000000000000000000, 2000000000): pd.Int64Dtype(), + (25, 10000000000000000000): pd.UInt64Dtype(), + (25, 1000000000000000000000000000): np.float128, + (None, None): pd.Int64Dtype(), + } + for (min_, max_), expected_pandas_type in expectations.items(): + constraints = RecordsSchemaFieldIntegerConstraints( + required=True, unique=None, min_=min_, max_=max_ + ) + yield with_nullable( + True, check_dtype + ), "integer", constraints, expected_pandas_type + + +def test_to_pandas_dtype_decimal_float(): + expectations = { + (8, 4): np.float16, + (20, 10): np.float32, + (40, 20): np.float64, + (80, 64): np.float128, + (500, 250): np.float128, + (None, None): np.float64, + } + for ( + fp_total_bits, + fp_significand_bits, + ), expected_pandas_type in expectations.items(): + constraints = RecordsSchemaFieldDecimalConstraints( + required=False, + unique=None, + fixed_precision=None, + fixed_scale=None, + fp_total_bits=fp_total_bits, + fp_significand_bits=fp_significand_bits, + ) + yield check_dtype, "decimal", constraints, expected_pandas_type + + +def test_to_pandas_dtype_misc(): + expectations = { + "boolean": np.bool_, + "string": np.object_, + "date": np.object_, + "datetime": "datetime64[ns]", + "datetimetz": "datetime64[ns, UTC]", + "time": np.object_, + } + for field_type, expected_pandas_type in expectations.items(): + yield check_dtype, field_type, None, expected_pandas_type + + +def test_to_pandas_dtype_fixed_precision_(): + check_dtype( + "decimal", + RecordsSchemaFieldDecimalConstraints( + required=False, + unique=None, + fixed_precision=1, + fixed_scale=1, + fp_total_bits=None, + fp_significand_bits=None, + ), + np.float64, + ) + + +def test_to_pandas_dtype_decimal_no_constraints(): + check_dtype("decimal", None, np.float64) diff --git a/tests/integration/records/multi_db/test_records_table2table.py b/tests/integration/records/multi_db/test_records_table2table.py index c63d4f922..2c9a528a7 100644 --- a/tests/integration/records/multi_db/test_records_table2table.py +++ b/tests/integration/records/multi_db/test_records_table2table.py @@ -103,7 +103,10 @@ def source2target(self): if __name__ == '__main__': - set_stream_logging() + set_stream_logging(level=logging.DEBUG) + logging.getLogger('botocore').setLevel(logging.INFO) + logging.getLogger('boto3').setLevel(logging.INFO) + logging.getLogger('urllib3').setLevel(logging.INFO) for source in DB_TYPES: for target in DB_TYPES: diff --git a/tests/integration/records/single_db/base_records_test.py b/tests/integration/records/single_db/base_records_test.py index 4aa26192e..284e1e988 100644 --- a/tests/integration/records/single_db/base_records_test.py +++ b/tests/integration/records/single_db/base_records_test.py @@ -29,7 +29,10 @@ logger = logging.getLogger(__name__) -set_stream_logging() +set_stream_logging(level=logging.DEBUG) +logging.getLogger('botocore').setLevel(logging.INFO) +logging.getLogger('boto3').setLevel(logging.INFO) +logging.getLogger('urllib3').setLevel(logging.INFO) class BaseRecordsIntegrationTest(unittest.TestCase): diff --git a/tests/integration/resources/_schema.json b/tests/integration/resources/_schema.json index 41c1b2a7a..57aebe79d 100644 --- a/tests/integration/resources/_schema.json +++ b/tests/integration/resources/_schema.json @@ -7,7 +7,7 @@ }, "numstr": { "type": "string", - "index": 1, + "index": 2, "constraints": { "max_length_bytes": 3 } diff --git a/tests/unit/records/schema/field/test_field.py b/tests/unit/records/schema/field/test_field.py index 800934fa9..b91e1831f 100644 --- a/tests/unit/records/schema/field/test_field.py +++ b/tests/unit/records/schema/field/test_field.py @@ -95,118 +95,6 @@ def test_to_sqlalchemy_column(self, mock_field_to_sqlalchemy_column): mock_field_to_sqlalchemy_column.assert_called_with(field, mock_driver) self.assertEqual(out, mock_field_to_sqlalchemy_column.return_value) - def test_to_numpy_dtype_integer(self): - mock_name = Mock(name='name') - mock_statistics = Mock(name='statistics') - mock_representations = Mock(name='representations') - mock_field_type = 'integer' - expectations = { - (-100, 100): np.int8, - (0, 240): np.uint8, - (-10000, 10000): np.int16, - (500, 40000): np.uint16, - (-200000000, 200000000): np.int32, - (25, 4000000000): np.uint32, - (-9000000000000000000, 2000000000): np.int64, - (25, 10000000000000000000): np.uint64, - (25, 1000000000000000000000000000): np.float128, - (None, None): np.int64 - } - for (mock_min, mock_max), expected_pandas_type in expectations.items(): - mock_constraints = Mock(name='constraints') - mock_constraints.min_ = mock_min - mock_constraints.max_ = mock_max - field = RecordsSchemaField(name=mock_name, - field_type=mock_field_type, - constraints=mock_constraints, - statistics=mock_statistics, - representations=mock_representations) - - out = field.to_numpy_dtype() - self.assertEqual(out, expected_pandas_type, f"min={mock_min}, max={mock_max}") - - def test_to_numpy_dtype_decimal_float(self): - mock_name = Mock(name='name') - mock_statistics = Mock(name='statistics') - mock_representations = Mock(name='representations') - mock_field_type = 'decimal' - expectations = { - (8, 4): np.float16, - (20, 10): np.float32, - (40, 20): np.float64, - (80, 64): np.float128, - (500, 250): np.float128, - (None, None): np.float64, - } - for (fp_total_bits, fp_significand_bits), expected_pandas_type in expectations.items(): - mock_constraints = Mock(name='constraints') - mock_constraints.fixed_precision = None - mock_constraints.fixed_scale = None - mock_constraints.fp_total_bits = fp_total_bits - mock_constraints.fp_significand_bits = fp_significand_bits - field = RecordsSchemaField(name=mock_name, - field_type=mock_field_type, - constraints=mock_constraints, - statistics=mock_statistics, - representations=mock_representations) - - out = field.to_numpy_dtype() - self.assertEqual(out, expected_pandas_type) - - def test_to_numpy_dtype_decimal_no_constraints(self): - mock_name = Mock(name='name') - mock_statistics = Mock(name='statistics') - mock_representations = Mock(name='representations') - mock_field_type = 'decimal' - field = RecordsSchemaField(name=mock_name, - field_type=mock_field_type, - constraints=None, - statistics=mock_statistics, - representations=mock_representations) - - out = field.to_numpy_dtype() - self.assertEqual(out, np.float64) - - def test_to_numpy_dtype_fixed_precision_(self): - mock_name = Mock(name='name') - mock_statistics = Mock(name='statistics') - mock_representations = Mock(name='representations') - mock_constraints = Mock(name='constraints') - mock_constraints.fixed_precision = 1 - mock_constraints.fixed_scale = 1 - mock_field_type = 'decimal' - field = RecordsSchemaField(name=mock_name, - field_type=mock_field_type, - constraints=mock_constraints, - statistics=mock_statistics, - representations=mock_representations) - - out = field.to_numpy_dtype() - self.assertEqual(out, np.float64) - - def test_to_numpy_dtype_misc(self): - mock_name = Mock(name='name') - mock_constraints = Mock(name='constraints') - mock_statistics = Mock(name='statistics') - mock_representations = Mock(name='representations') - expectations = { - 'boolean': np.bool_, - 'string': np.object_, - 'date': np.object_, - 'datetime': 'datetime64[ns]', - 'datetimetz': 'datetime64[ns, UTC]', - 'time': np.object_, - } - for field_type, expected_pandas_type in expectations.items(): - field = RecordsSchemaField(name=mock_name, - field_type=field_type, - constraints=mock_constraints, - statistics=mock_statistics, - representations=mock_representations) - - out = field.to_numpy_dtype() - self.assertEqual(out, expected_pandas_type) - def test_python_type_to_field_type(self): mock_unknown_type = Mock(name='unknown_type') out = RecordsSchemaField.python_type_to_field_type(mock_unknown_type) diff --git a/tests/unit/records/sources/test_fileobjs.py b/tests/unit/records/sources/test_fileobjs.py index d6dcc93ee..fa1d3c8b3 100644 --- a/tests/unit/records/sources/test_fileobjs.py +++ b/tests/unit/records/sources/test_fileobjs.py @@ -193,3 +193,20 @@ def test_move_to_records_directory(self, mock_MoveResult.assert_called_with(move_count=None, output_urls={'file.mumble': 'vmb://dir/file.mumble'}) self.assertEqual(out, mock_MoveResult.return_value) + + @patch('records_mover.records.sources.fileobjs.TemporaryDirectory') + @patch('records_mover.records.sources.fileobjs.FilesystemDirectoryUrl') + def test_temporary_unloadable_directory_loc(self, + mock_FilesystemDirectoryUrl, + mock_TemporaryDirectory): + mock_records_format = Mock(name='records_format') + mock_records_schema = Mock(name='records_schema') + mock_records_schema = Mock(name='records_schema') + mock_target_names_to_input_fileobjs = Mock(name='target_names_to_input_fileobjs') + source = FileobjsSource(target_names_to_input_fileobjs=mock_target_names_to_input_fileobjs, + records_schema=mock_records_schema, + records_format=mock_records_format) + with source.temporary_unloadable_directory_loc() as temp_loc: + self.assertEqual(temp_loc, mock_FilesystemDirectoryUrl.return_value) + mock_temp_dir = mock_TemporaryDirectory.return_value.__enter__.return_value + mock_FilesystemDirectoryUrl.assert_called_with(mock_temp_dir)