Skip to content

Commit

Permalink
Showing 9 changed files with 624 additions and 10 deletions.
1 change: 1 addition & 0 deletions python/cudf/cudf/__init__.py
Original file line number Diff line number Diff line change
@@ -41,6 +41,7 @@
from_pandas,
interval_range,
merge,
cut,
)
from cudf.core.algorithms import factorize
from cudf.core.dtypes import (
1 change: 1 addition & 0 deletions python/cudf/cudf/core/__init__.py
Original file line number Diff line number Diff line change
@@ -27,3 +27,4 @@
from cudf.core.multiindex import MultiIndex
from cudf.core.scalar import NA, Scalar
from cudf.core.series import Series
from cudf.core.cut import cut
9 changes: 8 additions & 1 deletion python/cudf/cudf/core/column/categorical.py
Original file line number Diff line number Diff line change
@@ -35,6 +35,7 @@
is_mixed_with_object_dtype,
min_signed_type,
min_unsigned_type,
is_interval_dtype,
)

if TYPE_CHECKING:
@@ -1092,7 +1093,13 @@ def to_pandas(self, index: pd.Index = None, **kwargs) -> pd.Series:

signed_dtype = min_signed_type(len(col.categories))
codes = col.cat().codes.astype(signed_dtype).fillna(-1).to_array()
categories = col.categories.dropna(drop_nan=True).to_pandas()
if is_interval_dtype(col.categories.dtype):
# leaving out dropna because it temporarily changes an interval
# index into a struct and throws off results.
# TODO: work on interval index dropna
categories = col.categories.to_pandas()
else:
categories = col.categories.dropna(drop_nan=True).to_pandas()
data = pd.Categorical.from_codes(
codes, categories=categories, ordered=col.ordered
)
3 changes: 2 additions & 1 deletion python/cudf/cudf/core/column/interval.py
Original file line number Diff line number Diff line change
@@ -72,9 +72,10 @@ def to_arrow(self):
return pa.ExtensionArray.from_storage(typ, struct_arrow)

def from_struct_column(self, closed="right"):
first_field_name = list(self.dtype.fields.keys())[0]
return IntervalColumn(
size=self.size,
dtype=IntervalDtype(self.dtype.fields["left"], closed),
dtype=IntervalDtype(self.dtype.fields[first_field_name], closed),
mask=self.base_mask,
offset=self.offset,
null_count=self.null_count,
296 changes: 296 additions & 0 deletions python/cudf/cudf/core/cut.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,296 @@
from cudf._lib.labeling import label_bins
from cudf.core.column import as_column
from cudf.core.column import build_categorical_column
from cudf.core.index import IntervalIndex, interval_range
from cudf.utils.dtypes import is_list_like
import cupy
import cudf
import numpy as np
import pandas as pd
from collections.abc import Sequence

# from cudf._lib.filling import sequence


def cut(
x,
bins,
right: bool = True,
labels=None,
retbins: bool = False,
precision: int = 3,
include_lowest: bool = False,
duplicates: str = "raise",
ordered: bool = True,
):

"""
Bin values into discrete intervals.
Use cut when you need to segment and sort data values into bins. This
function is also useful for going from a continuous variable to a
categorical variable.
Parameters
----------
x : array-like
The input array to be binned. Must be 1-dimensional.
bins : int, sequence of scalars, or IntervalIndex
The criteria to bin by.
* int : Defines the number of equal-width bins in the
range of x. The range of x is extended by .1% on each
side to include the minimum and maximum values of x.
right : bool, default True
Indicates whether bins includes the rightmost edge or not.
labels : array or False, default None
Specifies the labels for the returned bins. Must be the same
length as the resulting bins. If False, returns only integer
indicators of thebins. If True,raises an error. When ordered=False,
labels must be provided.
retbins : bool, default False
Whether to return the bins or not.
precision : int, default 3
The precision at which to store and display the bins labels.
include_lowest : bool, default False
Whether the first interval should be left-inclusive or not.
duplicates : {default 'raise', 'drop'}, optional
If bin edges are not unique, raise ValueError or drop non-uniques.
ordered : bool, default True
Whether the labels are ordered or not. Applies to returned types
Categorical and Series (with Categorical dtype). If True,
the resulting categorical will be ordered. If False, the resulting
categorical will be unordered (labels must be provided).
Returns
-------
out : CategoricalIndex
An array-like object representing the respective bin for each value
of x. The type depends on the value of labels.
bins : numpy.ndarray or IntervalIndex.
The computed or specified bins. Only returned when retbins=True.
For scalar or sequence bins, this is an ndarray with the computed
bins. If set duplicates=drop, bins will drop non-unique bin. For
an IntervalIndex bins, this is equal to bins.
Examples
--------
Discretize into three equal-sized bins.
>>> cudf.cut(np.array([1, 7, 5, 4, 6, 3]), 3)
CategoricalIndex([(0.994, 3.0], (5.0, 7.0], (3.0, 5.0], (3.0, 5.0],
... (5.0, 7.0],(0.994, 3.0]], categories=[(0.994, 3.0],
... (3.0, 5.0], (5.0, 7.0]], ordered=True, dtype='category')
>>> cudf.cut(np.array([1, 7, 5, 4, 6, 3]), 3, retbins=True)
(CategoricalIndex([(0.994, 3.0], (5.0, 7.0], (3.0, 5.0], (3.0, 5.0],
... (5.0, 7.0],(0.994, 3.0]],categories=[(0.994, 3.0],
... (3.0, 5.0], (5.0, 7.0]],ordered=True, dtype='category'),
array([0.994, 3. , 5. , 7. ]))
>>> cudf.cut(np.array([1, 7, 5, 4, 6, 3]),
... 3, labels=["bad", "medium", "good"])
CategoricalIndex(['bad', 'good', 'medium', 'medium', 'good', 'bad'],
... categories=['bad', 'medium', 'good'],ordered=True,
... dtype='category')
>>> cudf.cut(np.array([1, 7, 5, 4, 6, 3]), 3,
... labels=["B", "A", "B"], ordered=False)
CategoricalIndex(['B', 'B', 'A', 'A', 'B', 'B'], categories=['A', 'B'],
... ordered=False, dtype='category')
>>> cudf.cut([0, 1, 1, 2], bins=4, labels=False)
array([0, 1, 1, 3], dtype=int32)
Passing a Series as an input returns a Series with categorical dtype:
>>> s = cudf.Series(np.array([2, 4, 6, 8, 10]),
... index=['a', 'b', 'c', 'd', 'e'])
>>> cudf.cut(s, 3)
"""
left_inclusive = False
right_inclusive = True
# saving the original input x for use in case its a series
orig_x = x
old_bins = bins

if not ordered and labels is None:
raise ValueError("'labels' must be provided if 'ordered = False'")

if duplicates not in ["raise", "drop"]:
raise ValueError(
"invalid value for 'duplicates' parameter, valid options are: "
"raise, drop"
)

if labels is not False:
if not (labels is None or is_list_like(labels)):
raise ValueError(
"Bin labels must either be False, None or passed in as a "
"list-like argument"
)
elif ordered and labels is not None:
if len(set(labels)) != len(labels):
raise ValueError(
"labels must be unique if ordered=True;"
"pass ordered=False for duplicate labels"
)

# bins can either be an int, sequence of scalars or an intervalIndex
if isinstance(bins, Sequence):
if len(set(bins)) is not len(bins):
if duplicates == "raise":
raise ValueError(
f"Bin edges must be unique: {repr(bins)}.\n"
f"You can drop duplicate edges by setting the 'duplicates'"
"kwarg"
)
elif duplicates == "drop":
# get unique values but maintain list dtype
bins = list(dict.fromkeys(bins))

# if bins is an intervalIndex we ignore the value of right
elif isinstance(bins, (pd.IntervalIndex, cudf.IntervalIndex)):
right = bins.closed == "right"

# create bins if given an int or single scalar
if not isinstance(bins, pd.IntervalIndex):
if not isinstance(bins, (Sequence)):
if isinstance(
x, (pd.Series, cudf.Series, np.ndarray, cupy.ndarray)
):
mn = x.min()
mx = x.max()
else:
mn = min(x)
mx = max(x)
bins = np.linspace(mn, mx, bins + 1, endpoint=True)
adj = (mx - mn) * 0.001
if right:
bins[0] -= adj
else:
bins[-1] += adj

# if right and include lowest we adjust the first
# bin edge to make sure it is included
if right and include_lowest:
bins[0] = bins[0] - 10 ** (-precision)

# if right is false the last bin edge is not included
if not right:
right_edge = bins[-1]
x = cupy.asarray(x)
x[x == right_edge] = right_edge + 1

# adjust bin edges decimal precision
int_label_bins = np.around(bins, precision)

# the inputs is a column of the values in the array x
input_arr = as_column(x)

# checking for the correct inclusivity values
if right:
closed = "right"
else:
closed = "left"
left_inclusive = True

if isinstance(bins, pd.IntervalIndex):
interval_labels = bins
elif labels is None:
if duplicates == "drop" and len(bins) == 1 and len(old_bins) != 1:
if right and include_lowest:
old_bins[0] = old_bins[0] - 10 ** (-precision)
interval_labels = interval_range(
old_bins[0], old_bins[1], periods=1, closed=closed
)
else:
interval_labels = IntervalIndex.from_breaks(
old_bins, closed=closed
)
else:
# get labels for categories
interval_labels = IntervalIndex.from_breaks(
int_label_bins, closed=closed
)
elif labels is not False:
if not (is_list_like(labels)):
raise ValueError(
"Bin labels must either be False, None or passed in as a "
"list-like argument"
)
if ordered and len(set(labels)) != len(labels):
raise ValueError(
"labels must be unique if ordered=True; pass ordered=False for"
"duplicate labels"
)
else:
if len(labels) != len(bins) - 1:
raise ValueError(
"Bin labels must be one fewer than the number of bin edges"
)
if not ordered and len(set(labels)) != len(labels):
interval_labels = cudf.CategoricalIndex(
labels, categories=None, ordered=False
)
else:
interval_labels = (
labels if len(set(labels)) == len(labels) else None
)

if isinstance(bins, pd.IntervalIndex):
# get the left and right edges of the bins as columns
# we cannot typecast an IntervalIndex, so we need to
# make the edges the same type as the input array
left_edges = as_column(bins.left).astype(input_arr.dtype)
right_edges = as_column(bins.right).astype(input_arr.dtype)
else:
# get the left and right edges of the bins as columns
left_edges = as_column(bins[:-1:], dtype="float64")
right_edges = as_column(bins[+1::], dtype="float64")
# the input arr must be changed to the same type as the edges
input_arr = input_arr.astype(left_edges.dtype)
# get the indexes for the appropriate number
index_labels = label_bins(
input_arr, left_edges, left_inclusive, right_edges, right_inclusive
)

if labels is False:
# if labels is false we return the index labels, we return them
# as a series if we have a series input
if isinstance(orig_x, (pd.Series, cudf.Series)):
# need to run more tests but looks like in this case pandas
# always returns a float64 dtype
indx_arr_series = cudf.Series(index_labels, dtype="float64")
# if retbins we return the bins as well
if retbins:
return indx_arr_series, bins
else:
return indx_arr_series
elif retbins:
return index_labels.values, bins
else:
return index_labels.values

if labels is not None:
if labels is not ordered and len(set(labels)) != len(labels):
# when we have duplicate labels and ordered is False, we
# should allow duplicate categories. The categories are
# returned in order
new_data = [interval_labels[i][0] for i in index_labels.values]
return cudf.CategoricalIndex(
new_data, categories=sorted(set(labels)), ordered=False
)

col = build_categorical_column(
categories=interval_labels,
codes=index_labels,
mask=index_labels.base_mask,
offset=index_labels.offset,
size=index_labels.size,
ordered=ordered,
)

# we return a categorical index, as we don't have a Categorical method
categorical_index = cudf.core.index.as_index(col)

if isinstance(orig_x, (pd.Series, cudf.Series)):
# if we have a series input we return a series output
res_series = cudf.Series(categorical_index, index=orig_x.index)
if retbins:
return res_series, bins
else:
return res_series
elif retbins:
# if retbins is true we return the bins as well
return categorical_index, bins
else:
return categorical_index
3 changes: 2 additions & 1 deletion python/cudf/cudf/core/dtypes.py
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@
import pyarrow as pa
from pandas.api.extensions import ExtensionDtype
from pandas.core.arrays._arrow_utils import ArrowIntervalType
from cudf.utils.dtypes import is_interval_dtype

import cudf
from cudf._typing import Dtype
@@ -72,7 +73,7 @@ def to_pandas(self) -> pd.CategoricalDtype:
def _init_categories(self, categories: Any):
if categories is None:
return categories
if len(categories) == 0:
if len(categories) == 0 and not is_interval_dtype(categories):
dtype = "object" # type: Any
else:
dtype = None
4 changes: 2 additions & 2 deletions python/cudf/cudf/core/index.py
Original file line number Diff line number Diff line change
@@ -480,7 +480,6 @@ def to_frame(self, index=True, name=None):
col_name = 0
else:
col_name = self.name

return cudf.DataFrame(
{col_name: self._values}, index=self if index else None
)
@@ -1080,7 +1079,6 @@ def to_series(self, index=None, name=None):
Series
The dtype will be based on the type of the Index values.
"""

return cudf.Series(
self._values,
index=self.copy(deep=False) if index is None else index,
@@ -2796,6 +2794,8 @@ def as_index(arbitrary, **kwargs) -> BaseIndex:
return TimedeltaIndex(arbitrary, **kwargs)
elif isinstance(arbitrary, CategoricalColumn):
return CategoricalIndex(arbitrary, **kwargs)
elif isinstance(arbitrary, IntervalColumn):
return IntervalIndex(arbitrary, **kwargs)
elif isinstance(arbitrary, cudf.Series):
return as_index(arbitrary._column, **kwargs)
elif isinstance(arbitrary, pd.RangeIndex):
16 changes: 11 additions & 5 deletions python/cudf/cudf/core/series.py
Original file line number Diff line number Diff line change
@@ -48,6 +48,9 @@
can_convert_to_column,
find_common_type,
is_decimal_dtype,
is_struct_dtype,
is_categorical_dtype,
is_interval_dtype,
is_list_dtype,
is_list_like,
is_mixed_with_object_dtype,
@@ -1222,15 +1225,12 @@ def __repr__(self):
if get_option("display.max_rows") == 0
else get_option("display.max_rows")
)

if len(self) > max_rows and max_rows != 0:
top = self.head(int(max_rows / 2 + 1))
bottom = self.tail(int(max_rows / 2 + 1))

preprocess = cudf.concat([top, bottom])
else:
preprocess = self.copy()

preprocess.index = preprocess.index._clean_nulls_from_index()
if (
preprocess.nullable
@@ -1271,6 +1271,14 @@ def __repr__(self):
)
)
else:
if is_categorical_dtype(self):
if is_interval_dtype(
self.dtype.categories
) and is_struct_dtype(preprocess._column.categories):
# with a series input the IntervalIndex's are turn
# into a struct dtype this line will change the
# categories back to an intervalIndex.
preprocess.dtype._categories = self.dtype.categories
pd_series = preprocess.to_pandas()
output = pd_series.to_string(
name=self.name,
@@ -2865,7 +2873,6 @@ def to_pandas(self, index=True, nullable=False, **kwargs):
3 30.0
dtype: float64
"""

if index is True:
index = self.index.to_pandas()
s = self._column.to_pandas(index=index, nullable=nullable)
@@ -5182,7 +5189,6 @@ def value_counts(
1.0 1
dtype: int32
"""

if bins is not None:
raise NotImplementedError("bins is not yet supported")

301 changes: 301 additions & 0 deletions python/cudf/cudf/tests/test_cut.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,301 @@
# Copyright (c) 2021, NVIDIA CORPORATION.

"""
Test related to Cut
"""

import pandas as pd
import numpy as np
from cudf.core.cut import cut
import pytest
from cudf.tests.utils import assert_eq


@pytest.mark.parametrize(
"x", [[1, 7, 5, 4, 6, 3], [1, 7], np.array([1, 7, 5, 4, 6, 3])]
)
@pytest.mark.parametrize("bins", [1, 2, 3])
@pytest.mark.parametrize("right", [True, False])
@pytest.mark.parametrize("include_lowest", [True, False])
@pytest.mark.parametrize(
"ordered", [True]
) # if ordered is False we need labels
@pytest.mark.parametrize("precision", [1, 2, 3])
def test_cut_basic(x, bins, right, include_lowest, ordered, precision):
# will test optional labels, retbins and duplicates seperately
# they need more specific parameters to work
pcat = pd.cut(
x=x,
bins=bins,
right=right,
precision=precision,
include_lowest=include_lowest,
ordered=ordered,
)
pindex = pd.CategoricalIndex(pcat)
gindex = cut(
x=x,
bins=bins,
right=right,
precision=precision,
include_lowest=include_lowest,
ordered=ordered,
)

assert_eq(pindex, gindex)


@pytest.mark.parametrize("x", [[1, 7, 5, 4, 6, 3]])
@pytest.mark.parametrize("bins", [3]) # labels must be the same len as bins
@pytest.mark.parametrize("right", [True, False])
@pytest.mark.parametrize("include_lowest", [True, False])
@pytest.mark.parametrize(
"ordered", [True, False]
) # labels must be unique if ordered=True
@pytest.mark.parametrize("precision", [1, 2, 3])
@pytest.mark.parametrize(
"labels", [["bad", "medium", "good"], ["A", "B", "C"], [1, 2, 3], False]
)
def test_cut_labels(
x, bins, right, include_lowest, ordered, precision, labels
):

pcat = pd.cut(
x=x,
bins=bins,
right=right,
labels=labels,
precision=precision,
include_lowest=include_lowest,
ordered=ordered,
)
pindex = pd.CategoricalIndex(pcat) if labels else pcat
gindex = cut(
x=x,
bins=bins,
right=right,
labels=labels,
precision=precision,
include_lowest=include_lowest,
ordered=ordered,
)

assert_eq(pindex, gindex)


@pytest.mark.parametrize("x", [[1, 7, 5, 4, 6, 3]])
@pytest.mark.parametrize("bins", [3]) # labels must be the same len as bins
@pytest.mark.parametrize("right", [True, False])
@pytest.mark.parametrize("include_lowest", [True, False])
@pytest.mark.parametrize(
"ordered", [False]
) # labels must be unique if ordered=True
@pytest.mark.parametrize("precision", [1, 2, 3])
@pytest.mark.parametrize(
"labels", [["bad", "good", "good"], ["B", "A", "B"], [1, 2, 2], False]
)
def test_cut_labels_non_unique(
x, bins, right, include_lowest, ordered, precision, labels
):

pcat = pd.cut(
x=x,
bins=bins,
right=right,
labels=labels,
precision=precision,
include_lowest=include_lowest,
ordered=ordered,
)
pindex = pd.CategoricalIndex(pcat) if labels else pcat
gindex = cut(
x=x,
bins=bins,
right=right,
labels=labels,
precision=precision,
include_lowest=include_lowest,
ordered=ordered,
)

assert_eq(pindex, gindex)


@pytest.mark.parametrize(
"x",
[
[1, 7, 5, 4, 6, 3],
[1, 7],
np.array([1, 7, 5, 4, 6, 3]),
np.array([2, 4, 6, 8, 10]),
],
)
@pytest.mark.parametrize(
"bins", [1, 2, 3, [1, 2, 3], [0, 2, 4, 6, 10]],
)
@pytest.mark.parametrize("right", [True, False])
@pytest.mark.parametrize("precision", [3])
def test_cut_right(x, bins, right, precision):

pcat = pd.cut(x=x, bins=bins, right=right, precision=precision,)
pindex = pd.CategoricalIndex(pcat)
gindex = cut(x=x, bins=bins, right=right, precision=precision,)

assert_eq(pindex, gindex)


@pytest.mark.parametrize(
"x",
[
[1, 7, 5, 4, 6, 3],
[1, 7],
np.array([1, 7, 5, 4, 6, 3]),
np.array([2, 4, 6, 8, 10]),
],
)
@pytest.mark.parametrize(
"bins", [[0, 2, 4, 6, 10, 10], [1, 2, 2, 3, 3]],
)
@pytest.mark.parametrize("right", [True, False])
@pytest.mark.parametrize("include_lowest", [True, False])
@pytest.mark.parametrize("ordered", [True])
@pytest.mark.parametrize("precision", [1, 2, 3])
@pytest.mark.parametrize("duplicates", ["drop"])
def test_cut_drop_duplicates(
x, bins, right, precision, duplicates, ordered, include_lowest
):

pcat = pd.cut(
x=x,
bins=bins,
right=right,
precision=precision,
duplicates=duplicates,
include_lowest=include_lowest,
ordered=ordered,
)
pindex = pd.CategoricalIndex(pcat)
gindex = cut(
x=x,
bins=bins,
right=right,
precision=precision,
duplicates=duplicates,
include_lowest=include_lowest,
ordered=ordered,
)

assert_eq(pindex, gindex)


@pytest.mark.parametrize(
"x",
[
[1, 7, 5, 4, 6, 3],
[1, 7],
np.array([1, 7, 5, 4, 6, 3]),
np.array([2, 4, 6, 8, 10]),
],
)
@pytest.mark.parametrize(
"bins", [[0, 2, 4, 6, 10, 10], [1, 2, 2, 3, 3]],
)
@pytest.mark.parametrize("right", [True, False])
@pytest.mark.parametrize("include_lowest", [True, False])
@pytest.mark.parametrize("ordered", [True])
@pytest.mark.parametrize("precision", [1, 2, 3])
@pytest.mark.parametrize("duplicates", ["raises"])
def test_cut_drop_duplicates_raises(
x, bins, right, precision, duplicates, ordered, include_lowest
):
with pytest.raises(ValueError) as excgd:
cut(
x=x,
bins=bins,
right=right,
precision=precision,
duplicates=duplicates,
include_lowest=include_lowest,
ordered=ordered,
)
with pytest.raises(ValueError) as excpd:
pd.cut(
x=x,
bins=bins,
right=right,
precision=precision,
duplicates=duplicates,
include_lowest=include_lowest,
ordered=ordered,
)

assert_eq(str(excgd.value), str(excpd.value))


@pytest.mark.parametrize(
"x",
[
[0, 0.5, 1.5, 2.5, 4.5],
[1, 7, 5, 4, 6, 3],
[1, 7],
np.array([1, 7, 5, 4, 6, 3]),
np.array([2, 4, 6, 8, 10]),
],
)
@pytest.mark.parametrize(
"bins", [pd.IntervalIndex.from_tuples([(0, 1), (2, 3), (4, 5)])],
)
@pytest.mark.parametrize("right", [True, False])
@pytest.mark.parametrize("precision", [1, 2, 3])
@pytest.mark.parametrize("duplicates", ["drop", "raise"])
def test_cut_intervalindex_bin(x, bins, right, precision, duplicates):

pcat = pd.cut(
x=x,
bins=bins,
right=right,
precision=precision,
duplicates=duplicates,
)
pindex = pd.CategoricalIndex(pcat)
gindex = cut(
x=x,
bins=bins,
right=right,
precision=precision,
duplicates=duplicates,
)

assert_eq(pindex, gindex)


@pytest.mark.parametrize(
"x",
[pd.Series(np.array([2, 4, 6, 8, 10]), index=["a", "b", "c", "d", "e"])],
)
@pytest.mark.parametrize("bins", [1, 2, 3])
@pytest.mark.parametrize("right", [True, False])
@pytest.mark.parametrize("include_lowest", [True, False])
@pytest.mark.parametrize("ordered", [True])
@pytest.mark.parametrize("precision", [3])
def test_cut_series(x, bins, right, include_lowest, ordered, precision):

pcat = pd.cut(
x=x,
bins=bins,
right=right,
precision=precision,
include_lowest=include_lowest,
ordered=ordered,
)

gcat = cut(
x=x,
bins=bins,
right=right,
precision=precision,
include_lowest=include_lowest,
ordered=ordered,
)

assert_eq(pcat, gcat)

0 comments on commit 58b354f

Please sign in to comment.