From c83b9fdcf45aa0b7204ef0313dc0a778dc15e017 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Tue, 23 Jan 2024 06:19:34 -1000 Subject: [PATCH] Refactor and add validation to IntervalIndex.__init__ (#14778) * Adding validation to `closed`, `dtype` arguments in `ItervalIndex.__init__` * Ensure `closed` attribute always maps to `IntervalDtype.closed` * `build_interval_column` was no longer necessary by using `IntervalColumn` directly Authors: - Matthew Roeschke (https://github.com/mroeschke) Approvers: - Bradley Dice (https://github.com/bdice) - Vyas Ramasubramani (https://github.com/vyasr) URL: https://github.com/rapidsai/cudf/pull/14778 --- python/cudf/cudf/core/column/column.py | 50 +------- python/cudf/cudf/core/column/interval.py | 25 +--- python/cudf/cudf/core/index.py | 120 ++++++++++++------ .../cudf/cudf/tests/indexes/test_interval.py | 29 ++++- python/cudf/cudf/tests/test_udf_masked_ops.py | 4 +- 5 files changed, 114 insertions(+), 114 deletions(-) diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index 7a99ef9f470..dc060a7117e 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -999,14 +999,14 @@ def astype(self, dtype: Dtype, copy: bool = False) -> ColumnBase: "`.astype('str')` instead." ) return col.as_string_column(dtype) + elif isinstance(dtype, IntervalDtype): + return col.as_interval_column(dtype) elif isinstance(dtype, (ListDtype, StructDtype)): if not col.dtype == dtype: raise NotImplementedError( f"Casting {self.dtype} columns not currently supported" ) return col - elif isinstance(dtype, IntervalDtype): - return col.as_interval_column(dtype) elif isinstance(dtype, cudf.core.dtypes.DecimalDtype): return col.as_decimal_column(dtype) elif np.issubdtype(cast(Any, dtype), np.datetime64): @@ -1689,52 +1689,6 @@ def build_categorical_column( return cast("cudf.core.column.CategoricalColumn", result) -def build_interval_column( - left_col, - right_col, - mask=None, - size=None, - offset=0, - null_count=None, - closed="right", -): - """ - Build an IntervalColumn - - Parameters - ---------- - left_col : Column - Column of values representing the left of the interval - right_col : Column - Column of representing the right of the interval - mask : Buffer - Null mask - size : int, optional - offset : int, optional - closed : {"left", "right", "both", "neither"}, default "right" - Whether the intervals are closed on the left-side, right-side, - both or neither. - """ - left = as_column(left_col) - right = as_column(right_col) - if closed not in {"left", "right", "both", "neither"}: - closed = "right" - if type(left_col) is not list: - dtype = IntervalDtype(left_col.dtype, closed) - else: - dtype = IntervalDtype("int64", closed) - size = len(left) - return build_column( - data=None, - dtype=dtype, - mask=mask, - size=size, - offset=offset, - null_count=null_count, - children=(left, right), - ) - - def build_list_column( indices: ColumnBase, elements: ColumnBase, diff --git a/python/cudf/cudf/core/column/interval.py b/python/cudf/cudf/core/column/interval.py index 6a7e7729123..7227ef8ba3a 100644 --- a/python/cudf/cudf/core/column/interval.py +++ b/python/cudf/cudf/core/column/interval.py @@ -18,7 +18,6 @@ def __init__( offset=0, null_count=None, children=(), - closed="right", ): super().__init__( data=None, @@ -29,14 +28,6 @@ def __init__( null_count=null_count, children=children, ) - if closed in ["left", "right", "neither", "both"]: - self._closed = closed - else: - raise ValueError("closed value is not valid") - - @property - def closed(self): - return self._closed @classmethod def from_arrow(cls, data): @@ -50,7 +41,6 @@ def from_arrow(cls, data): offset = data.offset null_count = data.null_count children = new_col.children - closed = dtype.closed return IntervalColumn( size=size, @@ -59,7 +49,6 @@ def from_arrow(cls, data): offset=offset, null_count=null_count, children=children, - closed=closed, ) def to_arrow(self): @@ -73,7 +62,7 @@ def to_arrow(self): @classmethod def from_struct_column(cls, struct_column: StructColumn, closed="right"): - first_field_name = list(struct_column.dtype.fields.keys())[0] + first_field_name = next(iter(struct_column.dtype.fields.keys())) return IntervalColumn( size=struct_column.size, dtype=IntervalDtype( @@ -83,20 +72,19 @@ def from_struct_column(cls, struct_column: StructColumn, closed="right"): offset=struct_column.offset, null_count=struct_column.null_count, children=struct_column.base_children, - closed=closed, ) def copy(self, deep=True): - closed = self.closed struct_copy = super().copy(deep=deep) return IntervalColumn( size=struct_copy.size, - dtype=IntervalDtype(struct_copy.dtype.fields["left"], closed), + dtype=IntervalDtype( + struct_copy.dtype.fields["left"], self.dtype.closed + ), mask=struct_copy.base_mask, offset=struct_copy.offset, null_count=struct_copy.null_count, children=struct_copy.base_children, - closed=closed, ) def as_interval_column(self, dtype): @@ -109,7 +97,7 @@ def as_interval_column(self, dtype): # when creating an interval series or interval dataframe if dtype == "interval": dtype = IntervalDtype( - self.dtype.fields["left"], self.closed + self.dtype.subtype, self.dtype.closed ) children = self.children return IntervalColumn( @@ -119,7 +107,6 @@ def as_interval_column(self, dtype): offset=self.offset, null_count=self.null_count, children=children, - closed=dtype.closed, ) else: raise ValueError("dtype must be IntervalDtype") @@ -141,5 +128,5 @@ def to_pandas( def element_indexing(self, index: int): result = super().element_indexing(index) if cudf.get_option("mode.pandas_compatible"): - return pd.Interval(**result, closed=self._closed) + return pd.Interval(**result, closed=self.dtype.closed) return result diff --git a/python/cudf/cudf/core/index.py b/python/cudf/cudf/core/index.py index fa7173f1d0f..c10124f4de6 100644 --- a/python/cudf/cudf/core/index.py +++ b/python/cudf/cudf/core/index.py @@ -3174,10 +3174,12 @@ def interval_range( data = column.column_empty_like_same_mask(left_col, dtype) return IntervalIndex(data, closed=closed) - interval_col = column.build_interval_column( - left_col, right_col, closed=closed + interval_col = IntervalColumn( + dtype=IntervalDtype(left_col.dtype, closed), + size=len(left_col), + children=(left_col, right_col), ) - return IntervalIndex(interval_col) + return IntervalIndex(interval_col, closed=closed) class IntervalIndex(GenericIndex): @@ -3217,44 +3219,72 @@ class IntervalIndex(GenericIndex): def __init__( self, data, - closed=None, + closed: Optional[Literal["left", "right", "neither", "both"]] = None, dtype=None, - copy=False, + copy: bool = False, name=None, ): - if copy: - data = column.as_column(data, dtype=dtype).copy() - kwargs = _setdefault_name(data, name=name) - - if closed is None: - closed = "right" + name = _setdefault_name(data, name=name)["name"] - if isinstance(data, IntervalColumn): - data = data - elif isinstance(data, pd.Series) and isinstance( - data.dtype, pd.IntervalDtype - ): - data = column.as_column(data, data.dtype) - elif isinstance(data, (pd.Interval, pd.IntervalIndex)): - data = column.as_column( - data, - dtype=dtype, - ) - elif len(data) == 0: - subtype = getattr(data, "dtype", "int64") - dtype = IntervalDtype(subtype, closed) - data = column.column_empty_like_same_mask( - column.as_column(data), dtype + if dtype is not None: + dtype = cudf.dtype(dtype) + if not isinstance(dtype, IntervalDtype): + raise TypeError("dtype must be an IntervalDtype") + if closed is not None and closed != dtype.closed: + raise ValueError("closed keyword does not match dtype.closed") + closed = dtype.closed + + if closed is None and isinstance(dtype, IntervalDtype): + closed = dtype.closed + + closed = closed or "right" + + if len(data) == 0: + if not hasattr(data, "dtype"): + data = np.array([], dtype=np.int64) + elif isinstance(data.dtype, (pd.IntervalDtype, IntervalDtype)): + data = np.array([], dtype=data.dtype.subtype) + interval_col = IntervalColumn( + dtype=IntervalDtype(data.dtype, closed), + size=len(data), + children=(as_column(data), as_column(data)), ) else: - data = column.as_column(data) - data.dtype.closed = closed + col = as_column(data) + if not isinstance(col, IntervalColumn): + raise TypeError("data must be an iterable of Interval data") + if copy: + col = col.copy() + interval_col = IntervalColumn( + dtype=IntervalDtype(col.dtype.subtype, closed), + mask=col.mask, + size=col.size, + offset=col.offset, + null_count=col.null_count, + children=col.children, + ) - self.closed = closed - super().__init__(data, **kwargs) + if dtype: + interval_col = interval_col.astype(dtype) # type: ignore[assignment] + super().__init__(interval_col, name=name) + + @property + def closed(self): + return self._values.dtype.closed + + @classmethod @_cudf_nvtx_annotate - def from_breaks(breaks, closed="right", name=None, copy=False, dtype=None): + def from_breaks( + cls, + breaks, + closed: Optional[ + Literal["left", "right", "neither", "both"] + ] = "right", + name=None, + copy: bool = False, + dtype=None, + ): """ Construct an IntervalIndex from an array of splits. @@ -3283,16 +3313,28 @@ def from_breaks(breaks, closed="right", name=None, copy=False, dtype=None): >>> cudf.IntervalIndex.from_breaks([0, 1, 2, 3]) IntervalIndex([(0, 1], (1, 2], (2, 3]], dtype='interval[int64, right]') """ + breaks = as_column(breaks, dtype=dtype) if copy: - breaks = column.as_column(breaks, dtype=dtype).copy() - left_col = breaks[:-1:] - right_col = breaks[+1::] - - interval_col = column.build_interval_column( - left_col, right_col, closed=closed + breaks = breaks.copy() + left_col = breaks.slice(0, len(breaks) - 1) + right_col = breaks.slice(1, len(breaks)) + # For indexing, children should both have 0 offset + right_col = column.build_column( + data=right_col.data, + dtype=right_col.dtype, + size=right_col.size, + mask=right_col.mask, + offset=0, + null_count=right_col.null_count, + children=right_col.children, ) - return IntervalIndex(interval_col, name=name) + interval_col = IntervalColumn( + dtype=IntervalDtype(left_col.dtype, closed), + size=len(left_col), + children=(left_col, right_col), + ) + return IntervalIndex(interval_col, name=name, closed=closed) def __getitem__(self, index): raise NotImplementedError( diff --git a/python/cudf/cudf/tests/indexes/test_interval.py b/python/cudf/cudf/tests/indexes/test_interval.py index 52c49aebf35..5a6155ece29 100644 --- a/python/cudf/cudf/tests/indexes/test_interval.py +++ b/python/cudf/cudf/tests/indexes/test_interval.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. import numpy as np import pandas as pd import pyarrow as pa @@ -57,11 +57,9 @@ def test_interval_range_dtype_basic(start_t, end_t): @pytest.mark.parametrize("closed", ["left", "right", "both", "neither"]) -@pytest.mark.parametrize("start", [0]) -@pytest.mark.parametrize("end", [0]) -def test_interval_range_empty(start, end, closed): - pindex = pd.interval_range(start=start, end=end, closed=closed) - gindex = cudf.interval_range(start=start, end=end, closed=closed) +def test_interval_range_empty(closed): + pindex = pd.interval_range(start=0, end=0, closed=closed) + gindex = cudf.interval_range(start=0, end=0, closed=closed) assert_eq(pindex, gindex) @@ -315,3 +313,22 @@ def test_intervalindex_empty_typed_non_int(): result = cudf.IntervalIndex(data) expected = pd.IntervalIndex(data) assert_eq(result, expected) + + +def test_intervalindex_invalid_dtype(): + with pytest.raises(TypeError): + cudf.IntervalIndex([pd.Interval(1, 2)], dtype="int64") + + +def test_intervalindex_conflicting_closed(): + with pytest.raises(ValueError): + cudf.IntervalIndex( + [pd.Interval(1, 2)], + dtype=cudf.IntervalDtype("int64", closed="left"), + closed="right", + ) + + +def test_intervalindex_invalid_data(): + with pytest.raises(TypeError): + cudf.IntervalIndex([1, 2]) diff --git a/python/cudf/cudf/tests/test_udf_masked_ops.py b/python/cudf/cudf/tests/test_udf_masked_ops.py index ad0c961a749..11970944a95 100644 --- a/python/cudf/cudf/tests/test_udf_masked_ops.py +++ b/python/cudf/cudf/tests/test_udf_masked_ops.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021-2023, NVIDIA CORPORATION. +# Copyright (c) 2021-2024, NVIDIA CORPORATION. import math import operator @@ -636,7 +636,7 @@ def func(row): ["1.0", "2.0", "3.0"], dtype=cudf.Decimal64Dtype(2, 1) ), cudf.Series([1, 2, 3], dtype="category"), - cudf.interval_range(start=0, end=3, closed=True), + cudf.interval_range(start=0, end=3), [[1, 2], [3, 4], [5, 6]], [{"a": 1}, {"a": 2}, {"a": 3}], ],