Skip to content

Commit

Permalink
Refactor and add validation to IntervalIndex.__init__ (#14778)
Browse files Browse the repository at this point in the history
* Adding validation to `closed`, `dtype` arguments in `ItervalIndex.__init__`
* Ensure `closed` attribute always maps to `IntervalDtype.closed`
* `build_interval_column` was no longer necessary by using `IntervalColumn` directly

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)

Approvers:
  - Bradley Dice (https://github.com/bdice)
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: #14778
  • Loading branch information
mroeschke authored Jan 23, 2024
1 parent a39897c commit c83b9fd
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 114 deletions.
50 changes: 2 additions & 48 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -999,14 +999,14 @@ def astype(self, dtype: Dtype, copy: bool = False) -> ColumnBase:
"`.astype('str')` instead."
)
return col.as_string_column(dtype)
elif isinstance(dtype, IntervalDtype):
return col.as_interval_column(dtype)
elif isinstance(dtype, (ListDtype, StructDtype)):
if not col.dtype == dtype:
raise NotImplementedError(
f"Casting {self.dtype} columns not currently supported"
)
return col
elif isinstance(dtype, IntervalDtype):
return col.as_interval_column(dtype)
elif isinstance(dtype, cudf.core.dtypes.DecimalDtype):
return col.as_decimal_column(dtype)
elif np.issubdtype(cast(Any, dtype), np.datetime64):
Expand Down Expand Up @@ -1689,52 +1689,6 @@ def build_categorical_column(
return cast("cudf.core.column.CategoricalColumn", result)


def build_interval_column(
left_col,
right_col,
mask=None,
size=None,
offset=0,
null_count=None,
closed="right",
):
"""
Build an IntervalColumn
Parameters
----------
left_col : Column
Column of values representing the left of the interval
right_col : Column
Column of representing the right of the interval
mask : Buffer
Null mask
size : int, optional
offset : int, optional
closed : {"left", "right", "both", "neither"}, default "right"
Whether the intervals are closed on the left-side, right-side,
both or neither.
"""
left = as_column(left_col)
right = as_column(right_col)
if closed not in {"left", "right", "both", "neither"}:
closed = "right"
if type(left_col) is not list:
dtype = IntervalDtype(left_col.dtype, closed)
else:
dtype = IntervalDtype("int64", closed)
size = len(left)
return build_column(
data=None,
dtype=dtype,
mask=mask,
size=size,
offset=offset,
null_count=null_count,
children=(left, right),
)


def build_list_column(
indices: ColumnBase,
elements: ColumnBase,
Expand Down
25 changes: 6 additions & 19 deletions python/cudf/cudf/core/column/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def __init__(
offset=0,
null_count=None,
children=(),
closed="right",
):
super().__init__(
data=None,
Expand All @@ -29,14 +28,6 @@ def __init__(
null_count=null_count,
children=children,
)
if closed in ["left", "right", "neither", "both"]:
self._closed = closed
else:
raise ValueError("closed value is not valid")

@property
def closed(self):
return self._closed

@classmethod
def from_arrow(cls, data):
Expand All @@ -50,7 +41,6 @@ def from_arrow(cls, data):
offset = data.offset
null_count = data.null_count
children = new_col.children
closed = dtype.closed

return IntervalColumn(
size=size,
Expand All @@ -59,7 +49,6 @@ def from_arrow(cls, data):
offset=offset,
null_count=null_count,
children=children,
closed=closed,
)

def to_arrow(self):
Expand All @@ -73,7 +62,7 @@ def to_arrow(self):

@classmethod
def from_struct_column(cls, struct_column: StructColumn, closed="right"):
first_field_name = list(struct_column.dtype.fields.keys())[0]
first_field_name = next(iter(struct_column.dtype.fields.keys()))
return IntervalColumn(
size=struct_column.size,
dtype=IntervalDtype(
Expand All @@ -83,20 +72,19 @@ def from_struct_column(cls, struct_column: StructColumn, closed="right"):
offset=struct_column.offset,
null_count=struct_column.null_count,
children=struct_column.base_children,
closed=closed,
)

def copy(self, deep=True):
closed = self.closed
struct_copy = super().copy(deep=deep)
return IntervalColumn(
size=struct_copy.size,
dtype=IntervalDtype(struct_copy.dtype.fields["left"], closed),
dtype=IntervalDtype(
struct_copy.dtype.fields["left"], self.dtype.closed
),
mask=struct_copy.base_mask,
offset=struct_copy.offset,
null_count=struct_copy.null_count,
children=struct_copy.base_children,
closed=closed,
)

def as_interval_column(self, dtype):
Expand All @@ -109,7 +97,7 @@ def as_interval_column(self, dtype):
# when creating an interval series or interval dataframe
if dtype == "interval":
dtype = IntervalDtype(
self.dtype.fields["left"], self.closed
self.dtype.subtype, self.dtype.closed
)
children = self.children
return IntervalColumn(
Expand All @@ -119,7 +107,6 @@ def as_interval_column(self, dtype):
offset=self.offset,
null_count=self.null_count,
children=children,
closed=dtype.closed,
)
else:
raise ValueError("dtype must be IntervalDtype")
Expand All @@ -141,5 +128,5 @@ def to_pandas(
def element_indexing(self, index: int):
result = super().element_indexing(index)
if cudf.get_option("mode.pandas_compatible"):
return pd.Interval(**result, closed=self._closed)
return pd.Interval(**result, closed=self.dtype.closed)
return result
120 changes: 81 additions & 39 deletions python/cudf/cudf/core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -3174,10 +3174,12 @@ def interval_range(
data = column.column_empty_like_same_mask(left_col, dtype)
return IntervalIndex(data, closed=closed)

interval_col = column.build_interval_column(
left_col, right_col, closed=closed
interval_col = IntervalColumn(
dtype=IntervalDtype(left_col.dtype, closed),
size=len(left_col),
children=(left_col, right_col),
)
return IntervalIndex(interval_col)
return IntervalIndex(interval_col, closed=closed)


class IntervalIndex(GenericIndex):
Expand Down Expand Up @@ -3217,44 +3219,72 @@ class IntervalIndex(GenericIndex):
def __init__(
self,
data,
closed=None,
closed: Optional[Literal["left", "right", "neither", "both"]] = None,
dtype=None,
copy=False,
copy: bool = False,
name=None,
):
if copy:
data = column.as_column(data, dtype=dtype).copy()
kwargs = _setdefault_name(data, name=name)

if closed is None:
closed = "right"
name = _setdefault_name(data, name=name)["name"]

if isinstance(data, IntervalColumn):
data = data
elif isinstance(data, pd.Series) and isinstance(
data.dtype, pd.IntervalDtype
):
data = column.as_column(data, data.dtype)
elif isinstance(data, (pd.Interval, pd.IntervalIndex)):
data = column.as_column(
data,
dtype=dtype,
)
elif len(data) == 0:
subtype = getattr(data, "dtype", "int64")
dtype = IntervalDtype(subtype, closed)
data = column.column_empty_like_same_mask(
column.as_column(data), dtype
if dtype is not None:
dtype = cudf.dtype(dtype)
if not isinstance(dtype, IntervalDtype):
raise TypeError("dtype must be an IntervalDtype")
if closed is not None and closed != dtype.closed:
raise ValueError("closed keyword does not match dtype.closed")
closed = dtype.closed

if closed is None and isinstance(dtype, IntervalDtype):
closed = dtype.closed

closed = closed or "right"

if len(data) == 0:
if not hasattr(data, "dtype"):
data = np.array([], dtype=np.int64)
elif isinstance(data.dtype, (pd.IntervalDtype, IntervalDtype)):
data = np.array([], dtype=data.dtype.subtype)
interval_col = IntervalColumn(
dtype=IntervalDtype(data.dtype, closed),
size=len(data),
children=(as_column(data), as_column(data)),
)
else:
data = column.as_column(data)
data.dtype.closed = closed
col = as_column(data)
if not isinstance(col, IntervalColumn):
raise TypeError("data must be an iterable of Interval data")
if copy:
col = col.copy()
interval_col = IntervalColumn(
dtype=IntervalDtype(col.dtype.subtype, closed),
mask=col.mask,
size=col.size,
offset=col.offset,
null_count=col.null_count,
children=col.children,
)

self.closed = closed
super().__init__(data, **kwargs)
if dtype:
interval_col = interval_col.astype(dtype) # type: ignore[assignment]

super().__init__(interval_col, name=name)

@property
def closed(self):
return self._values.dtype.closed

@classmethod
@_cudf_nvtx_annotate
def from_breaks(breaks, closed="right", name=None, copy=False, dtype=None):
def from_breaks(
cls,
breaks,
closed: Optional[
Literal["left", "right", "neither", "both"]
] = "right",
name=None,
copy: bool = False,
dtype=None,
):
"""
Construct an IntervalIndex from an array of splits.
Expand Down Expand Up @@ -3283,16 +3313,28 @@ def from_breaks(breaks, closed="right", name=None, copy=False, dtype=None):
>>> cudf.IntervalIndex.from_breaks([0, 1, 2, 3])
IntervalIndex([(0, 1], (1, 2], (2, 3]], dtype='interval[int64, right]')
"""
breaks = as_column(breaks, dtype=dtype)
if copy:
breaks = column.as_column(breaks, dtype=dtype).copy()
left_col = breaks[:-1:]
right_col = breaks[+1::]

interval_col = column.build_interval_column(
left_col, right_col, closed=closed
breaks = breaks.copy()
left_col = breaks.slice(0, len(breaks) - 1)
right_col = breaks.slice(1, len(breaks))
# For indexing, children should both have 0 offset
right_col = column.build_column(
data=right_col.data,
dtype=right_col.dtype,
size=right_col.size,
mask=right_col.mask,
offset=0,
null_count=right_col.null_count,
children=right_col.children,
)

return IntervalIndex(interval_col, name=name)
interval_col = IntervalColumn(
dtype=IntervalDtype(left_col.dtype, closed),
size=len(left_col),
children=(left_col, right_col),
)
return IntervalIndex(interval_col, name=name, closed=closed)

def __getitem__(self, index):
raise NotImplementedError(
Expand Down
29 changes: 23 additions & 6 deletions python/cudf/cudf/tests/indexes/test_interval.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
import numpy as np
import pandas as pd
import pyarrow as pa
Expand Down Expand Up @@ -57,11 +57,9 @@ def test_interval_range_dtype_basic(start_t, end_t):


@pytest.mark.parametrize("closed", ["left", "right", "both", "neither"])
@pytest.mark.parametrize("start", [0])
@pytest.mark.parametrize("end", [0])
def test_interval_range_empty(start, end, closed):
pindex = pd.interval_range(start=start, end=end, closed=closed)
gindex = cudf.interval_range(start=start, end=end, closed=closed)
def test_interval_range_empty(closed):
pindex = pd.interval_range(start=0, end=0, closed=closed)
gindex = cudf.interval_range(start=0, end=0, closed=closed)

assert_eq(pindex, gindex)

Expand Down Expand Up @@ -315,3 +313,22 @@ def test_intervalindex_empty_typed_non_int():
result = cudf.IntervalIndex(data)
expected = pd.IntervalIndex(data)
assert_eq(result, expected)


def test_intervalindex_invalid_dtype():
with pytest.raises(TypeError):
cudf.IntervalIndex([pd.Interval(1, 2)], dtype="int64")


def test_intervalindex_conflicting_closed():
with pytest.raises(ValueError):
cudf.IntervalIndex(
[pd.Interval(1, 2)],
dtype=cudf.IntervalDtype("int64", closed="left"),
closed="right",
)


def test_intervalindex_invalid_data():
with pytest.raises(TypeError):
cudf.IntervalIndex([1, 2])
4 changes: 2 additions & 2 deletions python/cudf/cudf/tests/test_udf_masked_ops.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021-2023, NVIDIA CORPORATION.
# Copyright (c) 2021-2024, NVIDIA CORPORATION.
import math
import operator

Expand Down Expand Up @@ -636,7 +636,7 @@ def func(row):
["1.0", "2.0", "3.0"], dtype=cudf.Decimal64Dtype(2, 1)
),
cudf.Series([1, 2, 3], dtype="category"),
cudf.interval_range(start=0, end=3, closed=True),
cudf.interval_range(start=0, end=3),
[[1, 2], [3, 4], [5, 6]],
[{"a": 1}, {"a": 2}, {"a": 3}],
],
Expand Down

0 comments on commit c83b9fd

Please sign in to comment.