Skip to content

Commit

Permalink
Handle closed property in IntervalDtype.from_pandas (NVIDIA#10798)
Browse files Browse the repository at this point in the history
Since v1.3, `pandas.IntervalDtype` also has a `closed` property, so handle that in `IntervalDtype.from_pandas`.

While we're here, give `IntervalDtype` its own `__hash__` and `__eq__` (rather than deferring to `StructDtype`), fixing the previous behaviour in which dtypes that differ only in `closed` were treated as the same dtype:
```python
from cudf import IntervalDtype
dt1 = IntervalDtype("int32", "both")
dt2 = IntervalDtype("int32", "right")
dtypes = set([dt1, dt2])
print(len(dtypes))  # => 1
```

Authors:
  - Lawrence Mitchell (https://github.com/wence-)

Approvers:
  - Ashwin Srinath (https://github.com/shwina)
  - GALI PREM SAGAR (https://github.com/galipremsagar)

URL: rapidsai/cudf#10798
  • Loading branch information
wence- authored May 25, 2022
1 parent 31e1739 commit 5165319
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 9 deletions.
27 changes: 22 additions & 5 deletions python/cudf/cudf/core/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import cudf
from cudf._typing import Dtype
from cudf.core._compat import PANDAS_GE_130
from cudf.core.abc import Serializable
from cudf.core.buffer import Buffer

Expand Down Expand Up @@ -545,7 +546,7 @@ class IntervalDtype(StructDtype):
"""
subtype: str, np.dtype
The dtype of the Interval bounds.
closed: {right’, ‘left’, ‘both’, ‘neither}, default right
closed: {'right', 'left', 'both', 'neither'}, default 'right'
Whether the interval is closed on the left-side, right-side,
both or neither. See the Notes for more detailed explanation.
"""
Expand All @@ -555,6 +556,8 @@ class IntervalDtype(StructDtype):
def __init__(self, subtype, closed="right"):
super().__init__(fields={"left": subtype, "right": subtype})

if closed is None:
closed = "right"
if closed in ["left", "right", "neither", "both"]:
self.closed = closed
else:
Expand All @@ -565,7 +568,7 @@ def subtype(self):
return self.fields["left"]

def __repr__(self):
return f"interval[{self.fields['left']}]"
return f"interval[{self.subtype}, {self.closed}]"

@classmethod
def from_arrow(cls, typ):
Expand All @@ -579,9 +582,23 @@ def to_arrow(self):

@classmethod
def from_pandas(cls, pd_dtype: pd.IntervalDtype) -> "IntervalDtype":
return cls(
subtype=pd_dtype.subtype
) # TODO: needs `closed` when we upgrade Pandas
if PANDAS_GE_130:
return cls(subtype=pd_dtype.subtype, closed=pd_dtype.closed)
else:
return cls(subtype=pd_dtype.subtype)

def __eq__(self, other):
    """Return whether ``other`` denotes the same interval dtype.

    A string operand is compared against ``self.name``. Any other operand
    is equal only when it has exactly the same type as ``self`` and the
    same ``subtype`` and ``closed`` values.
    """
    if isinstance(other, str):
        # This means equality isn't transitive but mimics pandas
        return other == self.name
    return (
        # Exact-type match is intended; compare the classes by identity.
        type(self) is type(other)
        and self.subtype == other.subtype
        and self.closed == other.closed
    )

def __hash__(self):
    """Hash the dtype by its ``(subtype, closed)`` pair.

    These are the same two attributes that ``__eq__`` compares when the
    other operand is not a string.
    """
    identity = (self.subtype, self.closed)
    return hash(identity)

def serialize(self) -> Tuple[dict, list]:
header = {
Expand Down
28 changes: 24 additions & 4 deletions python/cudf/cudf/tests/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

import cudf
from cudf.core._compat import PANDAS_GE_130
from cudf.core.column import ColumnBase
from cudf.core.dtypes import (
CategoricalDtype,
Expand Down Expand Up @@ -164,15 +165,34 @@ def test_max_precision(decimal_type, max_precision):
decimal_type(scale=0, precision=max_precision + 1)


@pytest.mark.parametrize("fields", ["int64", "int32"])
@pytest.mark.parametrize("closed", ["left", "right", "both", "neither"])
def test_interval_dtype_pyarrow_round_trip(fields, closed):
pa_array = pd.core.arrays._arrow_utils.ArrowIntervalType(fields, closed)
@pytest.fixture(params=["int64", "int32"])
def subtype(request):
    """Supply each parametrized interval bound dtype name in turn."""
    current = request.param
    return current


@pytest.fixture(params=["left", "right", "both", "neither"])
def closed(request):
    """Supply each parametrized ``closed`` setting in turn."""
    current = request.param
    return current


def test_interval_dtype_pyarrow_round_trip(subtype, closed):
    """An Arrow interval type must survive ``from_arrow`` then ``to_arrow``."""
    arrow_type = pd.core.arrays._arrow_utils.ArrowIntervalType(subtype, closed)
    round_tripped = IntervalDtype.from_arrow(arrow_type).to_arrow()
    assert arrow_type.equals(round_tripped)


@pytest.mark.skipif(
    not PANDAS_GE_130,
    reason="pandas<1.3.0 doesn't have a closed argument for IntervalDtype",
)
def test_interval_dtype_from_pandas(subtype, closed):
    """``from_pandas`` must carry over both ``subtype`` and ``closed``."""
    pandas_dtype = pd.IntervalDtype(subtype, closed=closed)
    converted = cudf.IntervalDtype.from_pandas(pandas_dtype)
    assert cudf.IntervalDtype(subtype, closed=closed) == converted


def assert_column_array_dtype_equal(column: ColumnBase, array: pa.array):
"""
In cudf, each column holds its dtype. And since column may have child
Expand Down

0 comments on commit 5165319

Please sign in to comment.