diff --git a/python/cudf/cudf/core/_base_index.py b/python/cudf/cudf/core/_base_index.py index 37cf9655ef7..6e23e3d18b7 100644 --- a/python/cudf/cudf/core/_base_index.py +++ b/python/cudf/cudf/core/_base_index.py @@ -1432,7 +1432,94 @@ def rename(self, name, inplace=False): out.name = name return out - def get_slice_bound(self, label, side, kind=None): + def _indices_of(self, value) -> cudf.core.column.NumericalColumn: + """ + Return indices corresponding to value + + Parameters + ---------- + value + Value to look for in index + + Returns + ------- + Column of indices + """ + raise NotImplementedError + + def find_label_range(self, loc: slice) -> slice: + """ + Translate a label-based slice to an index-based slice + + Parameters + ---------- + loc + slice to search for. + + Notes + ----- + As with all label-based searches, the slice is right-closed. + + Returns + ------- + New slice translated into integer indices of the index (right-open). + """ + start = loc.start + stop = loc.stop + step = 1 if loc.step is None else loc.step + if step < 0: + start_side, stop_side = "right", "left" + else: + start_side, stop_side = "left", "right" + istart = ( + None + if start is None + else self.get_slice_bound(start, side=start_side) + ) + istop = ( + None + if stop is None + else self.get_slice_bound(stop, side=stop_side) + ) + if step < 0: + # Fencepost + istart = None if istart is None else max(istart - 1, 0) + istop = None if (istop is None or istop == 0) else istop - 1 + return slice(istart, istop, step) + + def searchsorted( + self, + value, + side: str = "left", + ascending: bool = True, + na_position: str = "last", + ): + """Find index where elements should be inserted to maintain order + + Parameters + ---------- + value : + Value to be hypothetically inserted into Self + side : str {'left', 'right'} optional, default 'left' + If 'left', the index of the first suitable location found is given + If 'right', return the last such index + ascending : bool optional, default True + Index is in ascending order (otherwise descending) + na_position : str {'last', 'first'} optional, default 'last' + Position of null values in sorted order + + Returns + ------- + Insertion point. + + Notes + ----- + As a precondition the index must be sorted in the same order + as requested by the `ascending` flag. + """ + raise NotImplementedError() + + def get_slice_bound(self, label, side: str, kind=None) -> int: """ Calculate slice bound that corresponds to given label. Returns leftmost (one-past-the-rightmost if ``side=='right'``) position @@ -1449,7 +1536,31 @@ def get_slice_bound(self, label, side, kind=None): int Index of label. """ - raise NotImplementedError + if kind is not None: + warnings.warn( + "'kind' argument in get_slice_bound is deprecated and will be " + "removed in a future version.", + FutureWarning, + ) + if side not in {"left", "right"}: + raise ValueError(f"Invalid side argument {side}") + if self.is_monotonic_increasing or self.is_monotonic_decreasing: + return self.searchsorted( + label, side=side, ascending=self.is_monotonic_increasing + ) + else: + try: + left, right = self._values._find_first_and_last(label) + except ValueError: + raise KeyError(f"{label=} not in index") + if left != right: + raise KeyError( + f"Cannot get slice bound for non-unique label {label=}" + ) + if side == "left": + return left + else: + return right + 1 def __array_function__(self, func, types, args, kwargs): # check if the function is implemented for the current type diff --git a/python/cudf/cudf/core/column/categorical.py b/python/cudf/cudf/core/column/categorical.py index 66e37fe57f0..f10af419cb7 100644 --- a/python/cudf/cudf/core/column/categorical.py +++ b/python/cudf/cudf/core/column/categorical.py @@ -1286,19 +1286,10 @@ def fillna( return result - def find_first_value( - self, value: ScalarLike, closest: bool = False - ) -> int: - """ - Returns offset of first value that matches - """ - return self.as_numerical.find_first_value(self._encode(value)) - - def find_last_value(self, value: ScalarLike, closest: bool = False) -> int: - """ - Returns offset of last value that matches - """ - return self.as_numerical.find_last_value(self._encode(value)) + def indices_of( + self, value: ScalarLike + ) -> cudf.core.column.NumericalColumn: + return self.as_numerical.indices_of(self._encode(value)) @property def is_monotonic_increasing(self) -> bool: diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index 4eee24017fa..15a71266cc1 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -45,6 +45,7 @@ drop_nulls, ) from cudf._lib.transform import bools_to_mask +from cudf._lib.types import size_type_dtype from cudf._typing import ColumnLike, Dtype, ScalarLike from cudf.api.types import ( _is_non_decimal_numeric_dtype, @@ -734,29 +735,79 @@ def notnull(self) -> ColumnBase: return result - def find_first_value( - self, value: ScalarLike, closest: bool = False - ) -> int: + def indices_of( + self, value: ScalarLike | Self + ) -> cudf.core.column.NumericalColumn: """ - Returns offset of first value that matches + Find locations of value in the column + + Parameters + ---------- + value + Scalar to look for (cast to dtype of column), or a length-1 column + + Returns + ------- + Column of indices that match value """ - # FIXME: Inefficient, may be need a libcudf api - index = cudf.core.index.RangeIndex(0, stop=len(self)) - indices = index.take(self == value) - if not len(indices): - raise ValueError("value not found") - return indices[0] - - def find_last_value(self, value: ScalarLike, closest: bool = False) -> int: + if not isinstance(value, ColumnBase): + value = as_column([value], dtype=self.dtype) + else: + assert len(value) == 1 + mask = libcudf.search.contains(value, self) + return apply_boolean_mask( + [arange(0, len(self), dtype=size_type_dtype)], mask + )[0] + + def _find_first_and_last(self, value: ScalarLike) -> Tuple[int, int]: + indices = self.indices_of(value) + if n := len(indices): + return ( + indices.element_indexing(0), + indices.element_indexing(n - 1), + ) + else: + raise ValueError(f"Value {value} not found in column") + + def find_first_value(self, value: ScalarLike) -> int: """ - Returns offset of last value that matches + Return index of first value that matches + + Parameters + ---------- + value + Value to search for (cast to dtype of column) + + Returns + ------- + Index of value + + Raises + ------ + ValueError if value is not found + """ + first, _ = self._find_first_and_last(value) + return first + + def find_last_value(self, value: ScalarLike) -> int: """ - # FIXME: Inefficient, may be need a libcudf api - index = cudf.core.index.RangeIndex(0, stop=len(self)) - indices = index.take(self == value) - if not len(indices): - raise ValueError("value not found") - return indices[-1] + Return index of last value that matches + + Parameters + ---------- + value + Value to search for (cast to dtype of column) + + Returns + ------- + Index of value + + Raises + ------ + ValueError if value is not found + """ + _, last = self._find_first_and_last(value) + return last def append(self, other: ColumnBase) -> ColumnBase: return concat_columns([self, as_column(other)]) @@ -893,39 +944,6 @@ def is_monotonic_decreasing(self) -> bool: ascending=[False], null_position=None ) - def get_slice_bound(self, label: ScalarLike, side: str, kind: str) -> int: - """ - Calculate slice bound that corresponds to given label. - Returns leftmost (one-past-the-rightmost if ``side=='right'``) position - of given label. - - Parameters - ---------- - label : Scalar - side : {'left', 'right'} - kind : {'ix', 'loc', 'getitem'} - """ - if kind not in {"ix", "loc", "getitem", None}: - raise ValueError( - f"Invalid value for ``kind`` parameter," - f" must be either one of the following: " - f"{'ix', 'loc', 'getitem', None}, but found: {kind}" - ) - if side not in {"left", "right"}: - raise ValueError( - "Invalid value for side kwarg," - " must be either 'left' or 'right': %s" % (side,) - ) - - # TODO: Handle errors/missing keys correctly - # Not currently using `kind` argument. - if side == "left": - return self.find_first_value(label, closest=True) - elif side == "right": - return self.find_last_value(label, closest=True) + 1 - else: - raise ValueError(f"Invalid value for side: {side}") - def sort_values( self: ColumnBase, ascending: bool = True, diff --git a/python/cudf/cudf/core/column/datetime.py b/python/cudf/cudf/core/column/datetime.py index d63b9adecbc..ccb91e85ecc 100644 --- a/python/cudf/cudf/core/column/datetime.py +++ b/python/cudf/cudf/core/column/datetime.py @@ -470,27 +470,13 @@ def fillna( return super().fillna(fill_value, method) - def find_first_value( - self, value: ScalarLike, closest: bool = False - ) -> int: - """ - Returns offset of first value that matches - """ - value = pd.to_datetime(value) - value = column.as_column( - value, dtype=self.dtype - ).as_numerical.element_indexing(0) - return self.as_numerical.find_first_value(value, closest=closest) - - def find_last_value(self, value: ScalarLike, closest: bool = False) -> int: - """ - Returns offset of last value that matches - """ - value = pd.to_datetime(value) + def indices_of( + self, value: ScalarLike + ) -> cudf.core.column.NumericalColumn: value = column.as_column( - value, dtype=self.dtype - ).as_numerical.element_indexing(0) - return self.as_numerical.find_last_value(value, closest=closest) + pd.to_datetime(value), dtype=self.dtype + ).as_numerical + return self.as_numerical.indices_of(value) @property def is_unique(self) -> bool: diff --git a/python/cudf/cudf/core/column/numerical.py b/python/cudf/cudf/core/column/numerical.py index c03794d5e5e..50795c22b82 100644 --- a/python/cudf/cudf/core/column/numerical.py +++ b/python/cudf/cudf/core/column/numerical.py @@ -13,12 +13,14 @@ cast, ) +import cupy as cp import numpy as np import pandas as pd import cudf from cudf import _lib as libcudf from cudf._lib.stream_compaction import drop_nulls +from cudf._lib.types import size_type_dtype from cudf._typing import ( ColumnBinaryOperand, ColumnLike, @@ -31,14 +33,9 @@ is_float_dtype, is_integer, is_integer_dtype, - is_number, is_scalar, ) -from cudf.core.buffer import ( - Buffer, - acquire_spill_lock, - cuda_array_interface_wrapper, -) +from cudf.core.buffer import Buffer, cuda_array_interface_wrapper from cudf.core.column import ( ColumnBase, as_column, @@ -49,14 +46,12 @@ ) from cudf.core.dtypes import CategoricalDtype from cudf.core.mixins import BinaryOperand -from cudf.utils import cudautils from cudf.utils.dtypes import ( NUMERIC_TYPES, min_column_type, min_signed_type, np_dtypes_to_pandas_dtypes, numeric_normalize_types, - to_cudf_compatible_scalar, ) from .numerical_base import NumericalBaseColumn @@ -124,6 +119,19 @@ def __contains__(self, item: ScalarLike) -> bool: self, column.as_column([item], dtype=self.dtype) ).any() + def indices_of(self, value: ScalarLike) -> NumericalColumn: + if ( + value is not None + and self.dtype.kind in {"c", "f"} + and np.isnan(value) + ): + return column.as_column( + cp.argwhere(cp.isnan(self.data_array_view(mode="read"))), + dtype=size_type_dtype, + ) + else: + return super().indices_of(value) + def has_nulls(self, include_nan=False): return self.null_count != 0 or ( self.nan_count != 0 if include_nan else False @@ -567,62 +575,6 @@ def fillna( return super(NumericalColumn, col).fillna(fill_value, method) - @acquire_spill_lock() - def _find_value( - self, value: ScalarLike, closest: bool, find: Callable, compare: str - ) -> int: - value = to_cudf_compatible_scalar(value) - if not is_number(value): - raise ValueError("Expected a numeric value") - found = 0 - if len(self): - found = find( - self.data_array_view(mode="read"), - value, - mask=self.mask, - ) - if found == -1: - if self.is_monotonic_increasing and closest: - found = find( - self.data_array_view(mode="read"), - value, - mask=self.mask, - compare=compare, - ) - if found == -1: - raise ValueError("value not found") - else: - raise ValueError("value not found") - return found - - def find_first_value( - self, value: ScalarLike, closest: bool = False - ) -> int: - """ - Returns offset of first value that matches. For monotonic - columns, returns the offset of the first larger value - if closest=True. - """ - if self.is_monotonic_increasing and closest: - if value < self.min(): - return 0 - elif value > self.max(): - return len(self) - return self._find_value(value, closest, cudautils.find_first, "gt") - - def find_last_value(self, value: ScalarLike, closest: bool = False) -> int: - """ - Returns offset of last value that matches. For monotonic - columns, returns the offset of the last smaller value - if closest=True. - """ - if self.is_monotonic_increasing and closest: - if value < self.min(): - return -1 - elif value > self.max(): - return len(self) - 1 - return self._find_value(value, closest, cudautils.find_last, "lt") - def can_cast_safely(self, to_dtype: DtypeObj) -> bool: """ Returns true if all the values in self can be diff --git a/python/cudf/cudf/core/column/string.py b/python/cudf/cudf/core/column/string.py index 9319881669f..3978e22691c 100644 --- a/python/cudf/cudf/core/column/string.py +++ b/python/cudf/cudf/core/column/string.py @@ -5717,23 +5717,6 @@ def fillna( else: return super().fillna(method=method) - def _find_first_and_last(self, value: ScalarLike) -> Tuple[int, int]: - found_indices = libcudf.search.contains( - column.as_column([value], dtype=self.dtype), self - ) - found_indices = libcudf.unary.cast(found_indices, dtype=np.int32) - first = column.as_column(found_indices).find_first_value(np.int32(1)) - last = column.as_column(found_indices).find_last_value(np.int32(1)) - return first, last - - def find_first_value( - self, value: ScalarLike, closest: bool = False - ) -> int: - return self._find_first_and_last(value)[0] - - def find_last_value(self, value: ScalarLike, closest: bool = False) -> int: - return self._find_first_and_last(value)[1] - def normalize_binop_value( self, other ) -> Union[column.ColumnBase, cudf.Scalar]: diff --git a/python/cudf/cudf/core/index.py b/python/cudf/cudf/core/index.py index 3ae0566182d..1389f6a20e3 100644 --- a/python/cudf/cudf/core/index.py +++ b/python/cudf/cudf/core/index.py @@ -27,6 +27,7 @@ from cudf._lib.datetime import extract_quarter, is_leap_year from cudf._lib.filling import sequence from cudf._lib.search import search_sorted +from cudf._lib.types import size_type_dtype from cudf.api.types import ( _is_non_decimal_numeric_dtype, is_categorical_dtype, @@ -200,6 +201,18 @@ def _copy_type_metadata( # have an underlying column. return self + def searchsorted( + self, + value: int, + side: str = "left", + ascending: bool = True, + na_position: str = "last", + ): + assert (len(self) <= 1) or ( + ascending == (self._step > 0) + ), "Invalid ascending flag" + return search_range(value, self.as_range, side=side) + @property # type: ignore @_cudf_nvtx_annotate def name(self): @@ -457,46 +470,6 @@ def dtype(self): dtype = np.dtype(np.int64) return _maybe_convert_to_default_type(dtype) - @_cudf_nvtx_annotate - def find_label_range(self, first=None, last=None): - """Find subrange in the ``RangeIndex``, marked by their positions, that - starts greater or equal to ``first`` and ends less or equal to ``last`` - - The range returned is assumed to be monotonically increasing. In cases - where there is no such range that suffice the constraint, an exception - will be raised. - - Parameters - ---------- - first, last : int, optional, Default None - The "start" and "stop" values of the subrange. If None, will use - ``self._start`` as first, ``self._stop`` as last. - - Returns - ------- - begin, end : 2-tuple of int - The starting index and the ending index. - The `last` value occurs at ``end - 1`` position. - """ - - first = self._start if first is None else first - last = self._stop if last is None else last - - if self._step < 0: - first = -first - last = -last - start = -self._start - step = -self._step - else: - start = self._start - step = self._step - - stop = start + len(self) * step - begin = search_range(start, stop, first, step, side="left") - end = search_range(start, stop, last, step, side="right") - - return begin, end - @_cudf_nvtx_annotate def to_pandas(self, nullable=False): return pd.RangeIndex( @@ -514,57 +487,20 @@ def is_unique(self): """ return True - @property # type: ignore + @cached_property + def as_range(self): + return range(self._start, self._stop, self._step) + + @cached_property # type: ignore @_cudf_nvtx_annotate def is_monotonic_increasing(self): return self._step > 0 or len(self) <= 1 - @property # type: ignore + @cached_property # type: ignore @_cudf_nvtx_annotate def is_monotonic_decreasing(self): return self._step < 0 or len(self) <= 1 - @_cudf_nvtx_annotate - def get_slice_bound(self, label, side, kind=None): - """ - Calculate slice bound that corresponds to given label. - Returns leftmost (one-past-the-rightmost if ``side=='right'``) position - of given label. - - Parameters - ---------- - label : int - A valid value in the ``RangeIndex`` - side : {'left', 'right'} - kind : Unused - To keep consistency with other index types. - - Returns - ------- - int - Index of label. - """ - if kind is not None: - warnings.warn( - "'kind' argument in get_slice_bound is deprecated and will be " - "removed in a future version.", - FutureWarning, - ) - if side not in {"left", "right"}: - raise ValueError(f"Unrecognized side parameter: {side}") - - if self._step < 0: - label = -label - start = -self._start - step = -self._step - else: - start = self._start - step = self._step - - stop = start + len(self) * step - pos = search_range(start, stop, label, step, side=side) - return pos - @_cudf_nvtx_annotate def memory_usage(self, deep=False): if deep: @@ -954,6 +890,13 @@ def any(self): def append(self, other): return self._as_int_index().append(other) + def _indices_of(self, value) -> cudf.core.column.NumericalColumn: + try: + i = [range(self._start, self._stop, self._step).index(value)] + except ValueError: + i = [] + return as_column(i, dtype=size_type_dtype) + def isin(self, values): if is_scalar(values): raise TypeError( @@ -1429,26 +1372,6 @@ def dtype(self): """ return self._values.dtype - @_cudf_nvtx_annotate - def find_label_range(self, first, last): - """Find range that starts with *first* and ends with *last*, - inclusively. - - Returns - ------- - begin, end : 2-tuple of int - The starting index and the ending index. - The *last* value occurs at ``end - 1`` position. - """ - col = self._values - begin, end = None, None - if first is not None: - begin = col.find_first_value(first, closest=True) - if last is not None: - end = col.find_last_value(last, closest=True) - end += 1 - return begin, end - @_cudf_nvtx_annotate def isna(self): return self._column.isnull().values @@ -1461,16 +1384,6 @@ def notna(self): notnull = notna - @_cudf_nvtx_annotate - def get_slice_bound(self, label, side, kind=None): - if kind is not None: - warnings.warn( - "'kind' argument in get_slice_bound is deprecated and will be " - "removed in a future version.", - FutureWarning, - ) - return self._values.get_slice_bound(label, side, kind) - def _is_numeric(self): return False @@ -1626,6 +1539,10 @@ def isin(self, values): return self._values.isin(values).values + def _indices_of(self, value): + """Return indices of value in index""" + return self._column.indices_of(value) + class NumericIndex(GenericIndex): """Immutable, ordered and sliceable sequence of labels. @@ -2045,6 +1962,18 @@ def __init__( super().__init__(data, **kwargs) + def searchsorted( + self, + value, + side: str = "left", + ascending: bool = True, + na_position: str = "last", + ): + value = self.dtype.type(value) + return super().searchsorted( + value, side=side, ascending=ascending, na_position=na_position + ) + @property # type: ignore @_cudf_nvtx_annotate def year(self): @@ -3365,6 +3294,14 @@ def from_arrow(cls, obj): # Try interpreting object as a MultiIndex before failing. return cudf.MultiIndex.from_arrow(obj) + @cached_property + def is_monotonic_increasing(self): + return super().is_monotonic_increasing + + @cached_property + def is_monotonic_decreasing(self): + return super().is_monotonic_decreasing + @_cudf_nvtx_annotate def _concat_range_index(indexes: List[RangeIndex]) -> BaseIndex: diff --git a/python/cudf/cudf/core/indexed_frame.py b/python/cudf/cudf/core/indexed_frame.py index 3128f766748..2d07fa23adb 100644 --- a/python/cudf/cudf/core/indexed_frame.py +++ b/python/cudf/cudf/core/indexed_frame.py @@ -213,8 +213,7 @@ def _get_label_range_or_mask(index, start, stop, step): boolean_mask = index <= stop return boolean_mask else: - start, stop = index.find_label_range(start, stop) - return slice(start, stop, step) + return index.find_label_range(slice(start, stop, step)) class _FrameIndexer: @@ -418,11 +417,7 @@ def _scan(self, op, axis=None, skipna=True): result_col = col else: if col.has_nulls(include_nan=True): - # Workaround as find_first_value doesn't seem to work - # in case of bools. - first_index = int( - col.isnull().astype("int8").find_first_value(1) - ) + first_index = col.isnull().find_first_value(True) result_col = col.copy() result_col[first_index:] = None else: diff --git a/python/cudf/cudf/core/multiindex.py b/python/cudf/cudf/core/multiindex.py index d8a08647f37..0083635bad1 100644 --- a/python/cudf/cudf/core/multiindex.py +++ b/python/cudf/cudf/core/multiindex.py @@ -560,6 +560,9 @@ def codes(self): self._compute_levels_and_codes() return self._codes + def get_slice_bound(self, label, side, kind=None): + raise NotImplementedError() + @property # type: ignore @_cudf_nvtx_annotate def nlevels(self): @@ -1510,7 +1513,7 @@ def from_pandas(cls, multiindex, nan_as_null=None): def is_unique(self): return len(self) == len(self.unique()) - @property # type: ignore + @cached_property # type: ignore @_cudf_nvtx_annotate def is_monotonic_increasing(self): """ @@ -1519,7 +1522,7 @@ def is_monotonic_increasing(self): """ return self._is_sorted(ascending=None, null_position=None) - @property # type: ignore + @cached_property # type: ignore @_cudf_nvtx_annotate def is_monotonic_decreasing(self): """ diff --git a/python/cudf/cudf/core/series.py b/python/cudf/cudf/core/series.py index b5392fcbe62..3de9985b759 100644 --- a/python/cudf/cudf/core/series.py +++ b/python/cudf/cudf/core/series.py @@ -307,12 +307,15 @@ def _loc_to_iloc(self, arg): found_index = arg return found_index try: - found_index = self._frame.index._values.find_first_value( - arg, closest=False - ) - return found_index + indices = self._frame.index._indices_of(arg) + if (n := len(indices)) == 0: + raise KeyError("Label scalar is out of bounds") + elif n == 1: + return indices.element_indexing(0) + else: + return indices except (TypeError, KeyError, IndexError, ValueError): - raise KeyError("label scalar is out of bound") + raise KeyError("Label scalar is out of bounds") elif isinstance(arg, slice): return _get_label_range_or_mask( diff --git a/python/cudf/cudf/tests/test_index.py b/python/cudf/cudf/tests/test_index.py index dde9ad3053d..03b4e76871c 100644 --- a/python/cudf/cudf/tests/test_index.py +++ b/python/cudf/cudf/tests/test_index.py @@ -75,23 +75,23 @@ def test_df_slice_empty_index(): def test_index_find_label_range_genericindex(): # Monotonic Index idx = cudf.Index(np.asarray([4, 5, 6, 10])) - assert idx.find_label_range(4, 6) == (0, 3) - assert idx.find_label_range(5, 10) == (1, 4) - assert idx.find_label_range(0, 6) == (0, 3) - assert idx.find_label_range(4, 11) == (0, 4) + assert idx.find_label_range(slice(4, 6)) == slice(0, 3, 1) + assert idx.find_label_range(slice(5, 10)) == slice(1, 4, 1) + assert idx.find_label_range(slice(0, 6)) == slice(0, 3, 1) + assert idx.find_label_range(slice(4, 11)) == slice(0, 4, 1) # Non-monotonic Index idx_nm = cudf.Index(np.asarray([5, 4, 6, 10])) - assert idx_nm.find_label_range(4, 6) == (1, 3) - assert idx_nm.find_label_range(5, 10) == (0, 4) + assert idx_nm.find_label_range(slice(4, 6)) == slice(1, 3, 1) + assert idx_nm.find_label_range(slice(5, 10)) == slice(0, 4, 1) # Last value not found - with pytest.raises(ValueError) as raises: - idx_nm.find_label_range(0, 6) - raises.match("value not found") + with pytest.raises(KeyError) as raises: + idx_nm.find_label_range(slice(0, 6)) + raises.match("not in index") # Last value not found - with pytest.raises(ValueError) as raises: - idx_nm.find_label_range(4, 11) - raises.match("value not found") + with pytest.raises(KeyError) as raises: + idx_nm.find_label_range(slice(4, 11)) + raises.match("not in index") def test_index_find_label_range_rangeindex(): @@ -99,18 +99,19 @@ def test_index_find_label_range_rangeindex(): # step > 0 # 3, 8, 13, 18 ridx = RangeIndex(3, 20, 5) - assert ridx.find_label_range(3, 8) == (0, 2) - assert ridx.find_label_range(0, 7) == (0, 1) - assert ridx.find_label_range(3, 19) == (0, 4) - assert ridx.find_label_range(2, 21) == (0, 4) + assert ridx.find_label_range(slice(3, 8)) == slice(0, 2, 1) + assert ridx.find_label_range(slice(0, 7)) == slice(0, 1, 1) + assert ridx.find_label_range(slice(3, 19)) == slice(0, 4, 1) + assert ridx.find_label_range(slice(2, 21)) == slice(0, 4, 1) # step < 0 # 20, 15, 10, 5 ridx = RangeIndex(20, 3, -5) - assert ridx.find_label_range(15, 10) == (1, 3) - assert ridx.find_label_range(10, 0) == (2, 4) - assert ridx.find_label_range(30, 13) == (0, 2) - assert ridx.find_label_range(30, 0) == (0, 4) + assert ridx.find_label_range(slice(15, 10)) == slice(1, 3, 1) + assert ridx.find_label_range(slice(10, 15, -1)) == slice(2, 0, -1) + assert ridx.find_label_range(slice(10, 0)) == slice(2, 4, 1) + assert ridx.find_label_range(slice(30, 13)) == slice(0, 2, 1) + assert ridx.find_label_range(slice(30, 0)) == slice(0, 4, 1) def test_index_comparision(): @@ -1329,7 +1330,6 @@ def test_float_index_apis(data, name, dtype): @pytest.mark.parametrize("ordered", [True, False]) @pytest.mark.parametrize("name", [1, "a", None]) def test_categorical_index_basic(data, categories, dtype, ordered, name): - # can't have both dtype and categories/ordered if dtype is not None: categories = None @@ -1586,7 +1586,6 @@ def test_interval_index_empty(closed): ], ) def test_interval_index_many_params(data, closed): - pindex = pd.IntervalIndex(data, closed=closed) gindex = IntervalIndex(data, closed=closed) @@ -1843,14 +1842,10 @@ def test_index_equals_categories(): def test_index_rangeindex_search_range(): # step > 0 ridx = RangeIndex(-13, 17, 4) - stop = ridx._start + ridx._step * len(ridx) + ri = ridx.as_range for i in range(len(ridx)): - assert i == search_range( - ridx._start, stop, ridx[i], ridx._step, side="left" - ) - assert i + 1 == search_range( - ridx._start, stop, ridx[i], ridx._step, side="right" - ) + assert i == search_range(ridx[i], ri, side="left") + assert i + 1 == search_range(ridx[i], ri, side="right") @pytest.mark.parametrize( @@ -2335,7 +2330,6 @@ def test_union_index(idx1, idx2, sort): ) @pytest.mark.parametrize("sort", [None, False]) def test_intersection_index(idx1, idx2, sort): - expected = idx1.intersection(idx2, sort=sort) idx1 = cudf.from_pandas(idx1) if isinstance(idx1, pd.Index) else idx1 diff --git a/python/cudf/cudf/tests/test_indexing.py b/python/cudf/cudf/tests/test_indexing.py index d9bb01ca794..bf280ed7844 100644 --- a/python/cudf/cudf/tests/test_indexing.py +++ b/python/cudf/cudf/tests/test_indexing.py @@ -1819,7 +1819,6 @@ def test_loc_multiindex_timestamp_issue_8585(index_type): assert_eq(expect, actual) -@pytest.mark.xfail(reason="https://github.com/rapidsai/cudf/issues/8693") def test_loc_repeated_index_label_issue_8693(): # https://github.com/rapidsai/cudf/issues/8693 s = pd.Series([1, 2, 3, 4], index=[0, 1, 1, 2]) @@ -1968,7 +1967,6 @@ def test_iloc_multiindex_lookup_as_label_issue_13515(indexer): assert_eq(expect, actual) -@pytest.mark.xfail(reason="https://github.com/rapidsai/cudf/issues/12833") def test_loc_unsorted_index_slice_lookup_keyerror_issue_12833(): # https://github.com/rapidsai/cudf/issues/12833 df = pd.DataFrame({"a": [1, 2, 3]}, index=[7, 0, 4]) @@ -2006,7 +2004,7 @@ def order(self, request): def take_order(self, request): return request.param - @pytest.fixture(params=["float", "int", "string"]) + @pytest.fixture(params=["float", "int", "string", "range"]) def dtype(self, request): return request.param @@ -2018,6 +2016,13 @@ def index(self, order, dtype): index = [-1, 10, 7, 14] elif dtype == "float": index = [-1.5, 7.10, 2.4, 11.2] + elif dtype == "range": + if order == "increasing": + return cudf.RangeIndex(2, 10, 3) + elif order == "decreasing": + return cudf.RangeIndex(10, 1, -3) + else: + return cudf.RangeIndex(10, 20, 3) else: raise ValueError(f"Unhandled index dtype {dtype}") if order == "decreasing": @@ -2051,12 +2056,6 @@ def test_loc_index_inindex_subset(self, df, take_order): def test_loc_index_notinindex_slice( self, request, df, order, dtype, take_order ): - if not (order == "increasing" and dtype in {"int", "float"}): - request.applymarker( - pytest.mark.xfail( - reason="https://github.com/rapidsai/cudf/issues/12833" - ) - ) pdf = df.to_pandas() lo = pdf.index[1] hi = pdf.index[-2] @@ -2066,7 +2065,7 @@ def test_loc_index_notinindex_slice( else: lo -= 1 hi += 1 - if order == "neither": + if order == "neither" and dtype != "range": with pytest.raises(KeyError): pdf.loc[lo:hi:take_order] with pytest.raises(KeyError): diff --git a/python/cudf/cudf/tests/test_monotonic.py b/python/cudf/cudf/tests/test_monotonic.py index 887d61aa152..37529973c8f 100644 --- a/python/cudf/cudf/tests/test_monotonic.py +++ b/python/cudf/cudf/tests/test_monotonic.py @@ -20,7 +20,6 @@ @pytest.mark.parametrize("testrange", [(10, 20, 1), (0, -10, -1), (5, 5, 1)]) def test_range_index(testrange): - index = RangeIndex( start=testrange[0], stop=testrange[1], step=testrange[2] ) @@ -52,7 +51,6 @@ def test_range_index(testrange): ], ) def test_generic_index(testlist): - index = GenericIndex(testlist) index_pd = pd.Index(testlist) @@ -76,7 +74,6 @@ def test_generic_index(testlist): ], ) def test_string_index(testlist): - index = cudf.Index(testlist) index_pd = pd.Index(testlist) @@ -94,7 +91,6 @@ def test_string_index(testlist): "testlist", [["c", "d", "e", "f"], ["z", "y", "x", "r"]] ) def test_categorical_index(testlist): - # Assuming unordered categorical data cannot be "monotonic" raw_cat = pd.Categorical(testlist, ordered=True) index = CategoricalIndex(raw_cat) @@ -141,7 +137,6 @@ def test_categorical_index(testlist): ], ) def test_datetime_index(testlist): - index = DatetimeIndex(testlist) index_pd = pd.DatetimeIndex(testlist) @@ -328,12 +323,9 @@ def test_get_slice_bound_missing(label, side, kind): assert got == expect -@pytest.mark.xfail @pytest.mark.parametrize("label", ["a", "c", "e", "g"]) @pytest.mark.parametrize("side", ["left", "right"]) def test_get_slice_bound_missing_str(label, side): - # Slicing for monotonic string indices not yet supported - # when missing values are specified (allowed in pandas) mylist = ["b", "d", "f"] index = GenericIndex(mylist) index_pd = pd.Index(mylist) diff --git a/python/cudf/cudf/utils/cudautils.py b/python/cudf/cudf/utils/cudautils.py index 32f4837baad..020c32de9f3 100755 --- a/python/cudf/cudf/utils/cudautils.py +++ b/python/cudf/cudf/utils/cudautils.py @@ -3,12 +3,9 @@ from pickle import dumps import cachetools -import numpy as np from numba import cuda from numba.np import numpy_support -import cudf -from cudf._lib.unary import is_non_nan from cudf.utils._numba import _CUDFNumbaConfig # @@ -16,81 +13,6 @@ # -# Find segments -def find_index_of_val(arr, val, mask=None, compare="eq"): - """ - Returns the indices of the occurrence of *val* in *arr* - as per *compare*, if not found it will be filled with - size of *arr* - - Parameters - ---------- - arr : device array - val : scalar - mask : mask of the array - compare: str ('gt', 'lt', or 'eq' (default)) - """ - arr = cudf.core.column.as_column(arr, nan_as_null=False) - if arr.size > 0: - if compare == "gt": - locations = arr <= val - elif compare == "lt": - locations = arr >= val - else: - locations = is_non_nan(arr) if np.isnan(val) else arr != val - - target = cudf.core.column.column.arange(0, arr.size, dtype="int32") - found = cudf._lib.copying._boolean_mask_scatter_scalar( - [cudf.Scalar(arr.size, dtype="int32").device_value], - [target], - locations, - )[0] - else: - found = cudf.core.column.column.column_empty(0, dtype="int32") - return cudf.core.column.column.as_column(found).set_mask(mask) - - -def find_first(arr, val, mask=None, compare="eq"): - """ - Returns the index of the first occurrence of *val* in *arr*.. - Or the first occurrence of *arr* *compare* *val*, if *compare* is not eq - Otherwise, returns -1. - - Parameters - ---------- - arr : device array - val : scalar - mask : mask of the array - compare: str ('gt', 'lt', or 'eq' (default)) - """ - found_col = find_index_of_val(arr, val, mask=mask, compare=compare) - found_col = found_col.find_and_replace([arr.size], [None], True) - - min_index = found_col.min() - return -1 if min_index is None or np.isnan(min_index) else min_index - - -def find_last(arr, val, mask=None, compare="eq"): - """ - Returns the index of the last occurrence of *val* in *arr*. - Or the last occurrence of *arr* *compare* *val*, if *compare* is not eq - Otherwise, returns -1. - - Parameters - ---------- - arr : device array - val : scalar - mask : mask of the array - compare: str ('gt', 'lt', or 'eq' (default)) - """ - - found_col = find_index_of_val(arr, val, mask=mask, compare=compare) - found_col = found_col.find_and_replace([arr.size], [None], True) - - max_index = found_col.max() - return -1 if max_index is None or np.isnan(max_index) else max_index - - @cuda.jit def gpu_window_sizes_from_offset(arr, window_sizes, offset): i = cuda.grid(1) diff --git a/python/cudf/cudf/utils/utils.py b/python/cudf/cudf/utils/utils.py index 5fbf91b49e9..e317085e63c 100644 --- a/python/cudf/cudf/utils/utils.py +++ b/python/cudf/cudf/utils/utils.py @@ -284,50 +284,63 @@ def _fillna_natwise(col): ) -def search_range(start, stop, x, step=1, side="left"): - """Find the position to insert a value in a range, so that the resulting - sequence remains sorted. - - When ``side`` is set to 'left', the insertion point ``i`` will hold the - following invariant: - `all(x < n for x in range_left) and all(x >= n for x in range_right)` - where ``range_left`` and ``range_right`` refers to the range to the left - and right of position ``i``, respectively. +def search_range(x: int, ri: range, *, side: str) -> int: + """ - When ``side`` is set to 'right', ``i`` will hold the following invariant: - `all(x <= n for x in range_left) and all(x > n for x in range_right)` + Find insertion point in a range to maintain sorted order Parameters ---------- - start : int - Start value of the series - stop : int - Stop value of the range - x : int - The value to insert - step : int, default 1 - Step value of the series, assumed positive - side : {'left', 'right'}, default 'left' - See description for usage. + x + Integer to insert + ri + Range to insert into + side + Tie-breaking decision for the case that `x` is a member of the + range. If `"left"` then the insertion point is before the + entry, otherwise it is after. Returns ------- int - Insertion position of n. + The insertion point + + See Also + -------- + numpy.searchsorted + + Notes + ----- + Let ``p`` be the return value, then if ``side="left"`` the + following invariants are maintained:: + + all(x < n for n in ri[:p]) + all(x >= n for n in ri[p:]) + + Conversely, if ``side="right"`` then we have:: + + all(x <= n for n in ri[:p]) + all(x > n for n in ri[p:]) Examples -------- For series: 1 4 7 - >>> search_range(start=1, stop=10, x=4, step=3, side="left") + >>> search_range(4, range(1, 10, 3), side="left") 1 - >>> search_range(start=1, stop=10, x=4, step=3, side="right") + >>> search_range(4, range(1, 10, 3), side="right") 2 """ - z = 1 if side == "left" else 0 - i = (x - start - z) // step + 1 + assert side in {"left", "right"} + if flip := (ri.step < 0): + ri = ri[::-1] + shift = int(side == "right") + else: + shift = int(side == "left") - length = (stop - start) // step - return max(min(length, i), 0) + offset = (x - ri.start - shift) // ri.step + 1 + if flip: + offset = len(ri) - offset + return max(min(len(ri), offset), 0) def _get_color_for_nvtx(name):