From c1b79313f0aa6d1bcbef73a3a1a3471512ecfce8 Mon Sep 17 00:00:00 2001
From: GALI PREM SAGAR
Date: Wed, 30 Aug 2023 17:28:40 -0500
Subject: [PATCH] Raise an error when timezone subtypes are encountered in `pd.IntervalDtype` (#14006)

closes #14004

This PR raises an error when an `IntervalIndex` contains a timezone-aware sub-type so that we don't go into infinite recursion.

Authors:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

Approvers:
  - Matthew Roeschke (https://github.com/mroeschke)

URL: https://github.com/rapidsai/cudf/pull/14006
---
 python/cudf/cudf/core/column/column.py  |  8 ++++++--
 python/cudf/cudf/tests/test_interval.py | 16 ++++++++++++++++
 2 files changed, 22 insertions(+), 2 deletions(-)

diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py
index d60f426c642..ad761ea8d18 100644
--- a/python/cudf/cudf/core/column/column.py
+++ b/python/cudf/cudf/core/column/column.py
@@ -2261,8 +2261,12 @@ def as_column(
         data = ColumnBase.from_scalar(arbitrary, length if length else 1)
     elif isinstance(arbitrary, pd.core.arrays.masked.BaseMaskedArray):
         data = as_column(pa.Array.from_pandas(arbitrary), dtype=dtype)
-    elif isinstance(arbitrary, pd.DatetimeIndex) and isinstance(
-        arbitrary.dtype, pd.DatetimeTZDtype
+    elif (
+        isinstance(arbitrary, pd.DatetimeIndex)
+        and isinstance(arbitrary.dtype, pd.DatetimeTZDtype)
+    ) or (
+        isinstance(arbitrary, pd.IntervalIndex)
+        and is_datetime64tz_dtype(arbitrary.dtype.subtype)
     ):
         raise NotImplementedError(
             "cuDF does not yet support timezone-aware datetimes"
diff --git a/python/cudf/cudf/tests/test_interval.py b/python/cudf/cudf/tests/test_interval.py
index f2e8f585a69..9704be44b95 100644
--- a/python/cudf/cudf/tests/test_interval.py
+++ b/python/cudf/cudf/tests/test_interval.py
@@ -165,3 +165,19 @@ def test_interval_index_unique():
     actual = gi.unique()
 
     assert_eq(expected, actual)
+
+
+@pytest.mark.parametrize("tz", ["US/Eastern", None])
+def test_interval_with_datetime(tz):
+    dti = pd.date_range(
+        start=pd.Timestamp("20180101", tz=tz),
+        end=pd.Timestamp("20181231", tz=tz),
+        freq="M",
+    )
+    pidx = pd.IntervalIndex.from_breaks(dti)
+    if tz is None:
+        gidx = cudf.from_pandas(pidx)
+        assert_eq(pidx, gidx)
+    else:
+        with pytest.raises(NotImplementedError):
+            cudf.from_pandas(pidx)