Skip to content

Commit

Permalink
REF/API: DatetimeTZDtype
Browse files Browse the repository at this point in the history
* Remove magic constructor from string
* Remove Caching

The remaining changes in the DatetimeArray PR will be to

1. Inherit from ExtensionDtype
2. Implement construct_array_type
3. Register
  • Loading branch information
TomAugspurger committed Nov 29, 2018
1 parent 580a094 commit 1ca7fa4
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 101 deletions.
25 changes: 15 additions & 10 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -978,16 +978,21 @@ def validate_tz_from_dtype(dtype, tz):
ValueError : on tzinfo mismatch
"""
if dtype is not None:
try:
dtype = DatetimeTZDtype.construct_from_string(dtype)
dtz = getattr(dtype, 'tz', None)
if dtz is not None:
if tz is not None and not timezones.tz_compare(tz, dtz):
raise ValueError("cannot supply both a tz and a dtype"
" with a tz")
tz = dtz
except TypeError:
pass
if isinstance(dtype, compat.string_types):
try:
dtype = DatetimeTZDtype.construct_from_string(dtype)
except TypeError:
# Things like `datetime64[ns]`, which is OK for the
# constructors, but also nonsense, which should be validated
# but not by us. We *do* allow non-existent tz errors to
# go through
pass
dtz = getattr(dtype, 'tz', None)
if dtz is not None:
if tz is not None and not timezones.tz_compare(tz, dtz):
raise ValueError("cannot supply both a tz and a dtype"
" with a tz")
tz = dtz
return tz


Expand Down
2 changes: 1 addition & 1 deletion pandas/core/dtypes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1789,7 +1789,7 @@ def _coerce_to_dtype(dtype):
ordered = getattr(dtype, 'ordered', False)
dtype = CategoricalDtype(categories=categories, ordered=ordered)
elif is_datetime64tz_dtype(dtype):
dtype = DatetimeTZDtype(dtype)
dtype = DatetimeTZDtype.construct_from_string(dtype)
elif is_period_dtype(dtype):
dtype = PeriodDtype(dtype)
elif is_interval_dtype(dtype):
Expand Down
130 changes: 69 additions & 61 deletions pandas/core/dtypes/dtypes.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
""" define extension dtypes """

import re

import numpy as np
import pytz

from pandas._libs.interval import Interval
from pandas._libs.tslibs import NaT, Period, Timestamp, timezones
Expand Down Expand Up @@ -483,99 +483,103 @@ class DatetimeTZDtype(PandasExtensionDtype):
str = '|M8[ns]'
num = 101
base = np.dtype('M8[ns]')
na_value = NaT
_metadata = ('unit', 'tz')
_match = re.compile(r"(datetime64|M8)\[(?P<unit>.+), (?P<tz>.+)\]")
_cache = {}
# TODO: restore caching? who cares though? It seems needlessly complex.
# np.dtype('datetime64[ns]') isn't a singleton

def __new__(cls, unit=None, tz=None):
""" Create a new unit if needed, otherwise return from the cache
def __init__(self, unit="ns", tz=None):
"""
An ExtensionDtype for timezone-aware datetime data.
Parameters
----------
unit : string unit that this represents, currently must be 'ns'
tz : string tz that this represents
"""
unit : str, default "ns"
The precision of the datetime data. Currently limited
to ``"ns"``.
tz : str, int, or datetime.tzinfo
The timezone.
Raises
------
pytz.UnknownTimeZoneError
When the requested timezone cannot be found.
Examples
--------
>>> pd.core.dtypes.dtypes.DatetimeTZDtype(tz='UTC')
datetime64[ns, UTC]
>>> pd.core.dtypes.dtypes.DatetimeTZDtype(tz='dateutil/US/Central')
datetime64[ns, tzfile('/usr/share/zoneinfo/US/Central')]
"""
if isinstance(unit, DatetimeTZDtype):
unit, tz = unit.unit, unit.tz

elif unit is None:
# we are called as an empty constructor
# generally for pickle compat
return object.__new__(cls)
if unit != 'ns':
raise ValueError("DatetimeTZDtype only supports ns units")

if tz:
tz = timezones.maybe_get_tz(tz)
elif tz is not None:
raise pytz.UnknownTimeZoneError(tz)
elif tz is None:
raise TypeError("A 'tz' is required.")

# we were passed a string that we can construct
try:
m = cls._match.search(unit)
if m is not None:
unit = m.groupdict()['unit']
tz = timezones.maybe_get_tz(m.groupdict()['tz'])
except TypeError:
raise ValueError("could not construct DatetimeTZDtype")

elif isinstance(unit, compat.string_types):

if unit != 'ns':
raise ValueError("DatetimeTZDtype only supports ns units")
self._unit = unit
self._tz = tz

unit = unit
tz = tz

if tz is None:
raise ValueError("DatetimeTZDtype constructor must have a tz "
"supplied")

# hash with the actual tz if we can
# some cannot be hashed, so stringfy
try:
key = (unit, tz)
hash(key)
except TypeError:
key = (unit, str(tz))
@property
def unit(self):
"""The precision of the datetime data."""
return self._unit

# set/retrieve from cache
try:
return cls._cache[key]
except KeyError:
u = object.__new__(cls)
u.unit = unit
u.tz = tz
cls._cache[key] = u
return u
@property
def tz(self):
"""The timezone."""
return self._tz

@classmethod
def construct_array_type(cls):
"""Return the array type associated with this dtype
Returns
-------
type
def construct_from_string(cls, string):
"""
from pandas import DatetimeIndex
return DatetimeIndex
Construct a DatetimeTZDtype from a string.
@classmethod
def construct_from_string(cls, string):
""" attempt to construct this type from a string, raise a TypeError if
it's not possible
Parameters
----------
string : str
The string alias for this DatetimeTZDtype.
Should be formatted like ``datetime64[ns, <tz>]``,
where ``<tz>`` is the timezone name.
Examples
--------
>>> DatetimeTZDtype.construct_from_string('datetime64[ns, UTC]')
datetime64[ns, UTC]
"""
msg = "could not construct DatetimeTZDtype"""
try:
return cls(unit=string)
match = cls._match.match(string)
if match:
d = match.groupdict()
return cls(unit=d['unit'], tz=d['tz'])
else:
raise TypeError(msg)
except ValueError:
raise TypeError("could not construct DatetimeTZDtype")
raise TypeError(msg)

def __unicode__(self):
# format the tz
return "datetime64[{unit}, {tz}]".format(unit=self.unit, tz=self.tz)

@property
def name(self):
"""A string representation of the dtype."""
return str(self)

def __hash__(self):
# make myself hashable
# TODO: update this.
return hash(str(self))

def __eq__(self, other):
Expand All @@ -586,6 +590,10 @@ def __eq__(self, other):
self.unit == other.unit and
str(self.tz) == str(other.tz))

def __getstate__(self):
# for pickle compat.
return self.__dict__


class PeriodDtype(ExtensionDtype, PandasExtensionDtype):
"""
Expand Down
7 changes: 5 additions & 2 deletions pandas/tests/dtypes/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,12 @@ def test_numpy_string_dtype(self):
'datetime64[ns, US/Eastern]',
'datetime64[ns, Asia/Tokyo]',
'datetime64[ns, UTC]'])
@pytest.mark.xfail(reason="dtype-caching", strict=True)
def test_datetimetz_dtype(self, dtype):
assert com.pandas_dtype(dtype) is DatetimeTZDtype(dtype)
assert com.pandas_dtype(dtype) == DatetimeTZDtype(dtype)
assert (com.pandas_dtype(dtype) is
DatetimeTZDtype.construct_from_string(dtype))
assert (com.pandas_dtype(dtype) ==
DatetimeTZDtype.construct_from_string(dtype))
assert com.pandas_dtype(dtype) == dtype

def test_categorical_dtype(self):
Expand Down
54 changes: 28 additions & 26 deletions pandas/tests/dtypes/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,19 +155,20 @@ def test_hash_vs_equality(self):
assert dtype == dtype2
assert dtype2 == dtype
assert dtype3 == dtype
assert dtype is dtype2
assert dtype2 is dtype
assert dtype3 is dtype
assert hash(dtype) == hash(dtype2)
assert hash(dtype) == hash(dtype3)

dtype4 = DatetimeTZDtype("ns", "US/Central")
assert dtype2 != dtype4
assert hash(dtype2) != hash(dtype4)

def test_construction(self):
pytest.raises(ValueError,
lambda: DatetimeTZDtype('ms', 'US/Eastern'))

def test_subclass(self):
a = DatetimeTZDtype('datetime64[ns, US/Eastern]')
b = DatetimeTZDtype('datetime64[ns, CET]')
a = DatetimeTZDtype.construct_from_string('datetime64[ns, US/Eastern]')
b = DatetimeTZDtype.construct_from_string('datetime64[ns, CET]')

assert issubclass(type(a), type(a))
assert issubclass(type(a), type(b))
Expand All @@ -189,8 +190,6 @@ def test_compat(self):
assert not is_datetime64_dtype('datetime64[ns, US/Eastern]')

def test_construction_from_string(self):
result = DatetimeTZDtype('datetime64[ns, US/Eastern]')
assert is_dtype_equal(self.dtype, result)
result = DatetimeTZDtype.construct_from_string(
'datetime64[ns, US/Eastern]')
assert is_dtype_equal(self.dtype, result)
Expand Down Expand Up @@ -255,14 +254,13 @@ def test_dst(self):
def test_parser(self, tz, constructor):
# pr #11245
dtz_str = '{con}[ns, {tz}]'.format(con=constructor, tz=tz)
result = DatetimeTZDtype(dtz_str)
result = DatetimeTZDtype.construct_from_string(dtz_str)
expected = DatetimeTZDtype('ns', tz)
assert result == expected

def test_empty(self):
dt = DatetimeTZDtype()
with pytest.raises(AttributeError):
str(dt)
with pytest.raises(TypeError, match="A 'tz' is required."):
DatetimeTZDtype()


class TestPeriodDtype(Base):
Expand Down Expand Up @@ -795,34 +793,38 @@ def test_update_dtype_errors(self, bad_dtype):
dtype.update_dtype(bad_dtype)


@pytest.mark.parametrize(
'dtype',
[CategoricalDtype, IntervalDtype])
@pytest.mark.parametrize('dtype', [
CategoricalDtype,
IntervalDtype,
])
def test_registry(dtype):
assert dtype in registry.dtypes


@pytest.mark.parametrize('dtype', [DatetimeTZDtype, PeriodDtype])
@pytest.mark.parametrize('dtype', [
PeriodDtype,
DatetimeTZDtype,
])
def test_pandas_registry(dtype):
assert dtype not in registry.dtypes
assert dtype in _pandas_registry.dtypes


@pytest.mark.parametrize(
'dtype, expected',
[('int64', None),
('interval', IntervalDtype()),
('interval[int64]', IntervalDtype()),
('interval[datetime64[ns]]', IntervalDtype('datetime64[ns]')),
('category', CategoricalDtype())])
@pytest.mark.parametrize('dtype, expected', [
('int64', None),
('interval', IntervalDtype()),
('interval[int64]', IntervalDtype()),
('interval[datetime64[ns]]', IntervalDtype('datetime64[ns]')),
('category', CategoricalDtype()),
])
def test_registry_find(dtype, expected):
assert registry.find(dtype) == expected


@pytest.mark.parametrize(
'dtype, expected',
[('period[D]', PeriodDtype('D')),
('datetime64[ns, US/Eastern]', DatetimeTZDtype('ns', 'US/Eastern'))])
@pytest.mark.parametrize('dtype, expected', [
('period[D]', PeriodDtype('D')),
('datetime64[ns, US/Eastern]', DatetimeTZDtype('ns', 'US/Eastern')),
])
def test_pandas_registry_find(dtype, expected):
assert _pandas_registry.find(dtype) == expected

Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/dtypes/test_missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def test_array_equivalent_str():
# Datetime-like
(np.dtype("M8[ns]"), NaT),
(np.dtype("m8[ns]"), NaT),
(DatetimeTZDtype('datetime64[ns, US/Eastern]'), NaT),
(DatetimeTZDtype.construct_from_string('datetime64[ns, US/Eastern]'), NaT),
(PeriodDtype("M"), NaT),
# Integer
('u1', 0), ('u2', 0), ('u4', 0), ('u8', 0),
Expand Down

0 comments on commit 1ca7fa4

Please sign in to comment.