diff --git a/python/cudf/cudf/_lib/cpp/stream_compaction.pxd b/python/cudf/cudf/_lib/cpp/stream_compaction.pxd index 61efd040807..bba2d1ffb7c 100644 --- a/python/cudf/cudf/_lib/cpp/stream_compaction.pxd +++ b/python/cudf/cudf/_lib/cpp/stream_compaction.pxd @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2023, NVIDIA CORPORATION. from libcpp cimport bool from libcpp.memory cimport unique_ptr @@ -20,6 +20,7 @@ from cudf._lib.cpp.types cimport ( cdef extern from "cudf/stream_compaction.hpp" namespace "cudf" \ nogil: ctypedef enum duplicate_keep_option: + KEEP_ANY 'cudf::duplicate_keep_option::KEEP_ANY' KEEP_FIRST 'cudf::duplicate_keep_option::KEEP_FIRST' KEEP_LAST 'cudf::duplicate_keep_option::KEEP_LAST' KEEP_NONE 'cudf::duplicate_keep_option::KEEP_NONE' @@ -33,13 +34,14 @@ cdef extern from "cudf/stream_compaction.hpp" namespace "cudf" \ column_view boolean_mask ) except + - cdef unique_ptr[table] unique( - table_view source_table, - vector[size_type] keys, - duplicate_keep_option keep, - null_equality nulls_equal) except + - cdef size_type distinct_count( column_view source_table, null_policy null_handling, nan_policy nan_handling) except + + + cdef unique_ptr[table] stable_distinct( + table_view input, + vector[size_type] keys, + duplicate_keep_option keep, + null_equality nulls_equal, + ) except + diff --git a/python/cudf/cudf/_lib/stream_compaction.pyx b/python/cudf/cudf/_lib/stream_compaction.pyx index 143999e52ef..4422ad83885 100644 --- a/python/cudf/cudf/_lib/stream_compaction.pyx +++ b/python/cudf/cudf/_lib/stream_compaction.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2023, NVIDIA CORPORATION. from cudf.core.buffer import acquire_spill_lock @@ -9,22 +9,19 @@ from libcpp.vector cimport vector from cudf._lib.column cimport Column from cudf._lib.cpp.column.column_view cimport column_view -from cudf._lib.cpp.sorting cimport stable_sort_by_key as cpp_stable_sort_by_key from cudf._lib.cpp.stream_compaction cimport ( apply_boolean_mask as cpp_apply_boolean_mask, distinct_count as cpp_distinct_count, drop_nulls as cpp_drop_nulls, duplicate_keep_option, - unique as cpp_unique, + stable_distinct as cpp_stable_distinct, ) from cudf._lib.cpp.table.table cimport table from cudf._lib.cpp.table.table_view cimport table_view from cudf._lib.cpp.types cimport ( nan_policy, null_equality, - null_order, null_policy, - order, size_type, ) from cudf._lib.utils cimport columns_from_unique_ptr, table_view_from_columns @@ -145,41 +142,13 @@ def drop_duplicates(list columns, if nulls_are_equal else null_equality.UNEQUAL ) - - cdef vector[order] column_order = ( - vector[order]( - cpp_keys.size(), - order.ASCENDING - ) - ) - cdef vector[null_order] null_precedence = ( - vector[null_order]( - cpp_keys.size(), - null_order.BEFORE - ) - ) - cdef table_view source_table_view = table_view_from_columns(columns) - cdef table_view keys_view = source_table_view.select(cpp_keys) - cdef unique_ptr[table] sorted_source_table cdef unique_ptr[table] c_result with nogil: - # cudf::unique keeps unique rows in each consecutive group of - # equivalent rows. To match the behavior of pandas.DataFrame. - # drop_duplicates, users need to stable sort the input first - # and then invoke cudf::unique. - sorted_source_table = move( - cpp_stable_sort_by_key( - source_table_view, - keys_view, - column_order, - null_precedence - ) - ) c_result = move( - cpp_unique( - sorted_source_table.get().view(), + cpp_stable_distinct( + source_table_view, cpp_keys, cpp_keep_option, cpp_nulls_equal diff --git a/python/cudf/cudf/core/_base_index.py b/python/cudf/cudf/core/_base_index.py index a2e3bc44f3a..1fe30179001 100644 --- a/python/cudf/cudf/core/_base_index.py +++ b/python/cudf/cudf/core/_base_index.py @@ -1476,7 +1476,12 @@ def __array_function__(self, func, types, args, kwargs): if cudf_func is func: return NotImplemented else: - return cudf_func(*args, **kwargs) + result = cudf_func(*args, **kwargs) + if fname == "unique": + # NumPy expects a sorted result for `unique`, which is not + # guaranteed by cudf.Index.unique. + result = result.sort_values() + return result else: return NotImplemented diff --git a/python/cudf/cudf/core/column/categorical.py b/python/cudf/cudf/core/column/categorical.py index 39332807139..d28851f4ace 100644 --- a/python/cudf/cudf/core/column/categorical.py +++ b/python/cudf/cudf/core/column/categorical.py @@ -1042,8 +1042,15 @@ def data_array_view( ) -> cuda.devicearray.DeviceNDArray: return self.codes.data_array_view(mode=mode) - def unique(self, preserve_order=False) -> CategoricalColumn: - codes = self.as_numerical.unique(preserve_order=preserve_order) + def unique(self, preserve_order=True) -> CategoricalColumn: + if preserve_order is not True: + warnings.warn( + "The preserve_order argument is deprecated. It will be " + "removed in a future version. As of now, unique always " + "preserves order regardless of the argument's value.", + FutureWarning, + ) + codes = self.as_numerical.unique() return column.build_categorical_column( categories=self.categories, codes=column.build_column(codes.base_data, dtype=codes.dtype), @@ -1397,9 +1404,7 @@ def _concat( head = next((obj for obj in objs if obj.valid_count), objs[0]) # Combine and de-dupe the categories - cats = column.concat_columns([o.categories for o in objs]).unique( - preserve_order=True - ) + cats = column.concat_columns([o.categories for o in objs]).unique() objs = [o._set_categories(cats, is_unique=True) for o in objs] codes = [o.codes for o in objs] @@ -1538,10 +1543,7 @@ def _set_categories( # Ensure new_categories is unique first if not (is_unique or new_cats.is_unique): - # drop_duplicates() instead of unique() to preserve order - new_cats = cudf.Series(new_cats)._column.unique( - preserve_order=True - ) + new_cats = cudf.Series(new_cats)._column.unique() cur_codes = self.codes max_cat_size = ( diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index 607bf83ff6c..255ac2582af 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -1021,17 +1021,16 @@ def as_categorical_column(self, dtype, **kwargs) -> ColumnBase: ordered=dtype.ordered, ) - cats = self.unique().astype(self.dtype) + # Categories must be unique and sorted in ascending order. + cats = self.unique().sort_by_values()[0].astype(self.dtype) label_dtype = min_unsigned_type(len(cats)) labels = self._label_encoding( cats=cats, dtype=label_dtype, na_sentinel=cudf.Scalar(1) ) - # columns include null index in factorization; remove: if self.has_nulls(): cats = cats.dropna(drop_nan=False) min_type = min_unsigned_type(len(cats), 8) - labels = labels - 1 if cudf.dtype(min_type).itemsize < labels.dtype.itemsize: labels = labels.astype(min_type) @@ -1132,25 +1131,17 @@ def searchsorted( values, side, ascending=ascending, na_position=na_position ) - def unique(self, preserve_order=False) -> ColumnBase: + def unique(self, preserve_order=True) -> ColumnBase: """ Get unique values in the data """ - # TODO: We could avoid performing `drop_duplicates` for - # columns with values that already are unique. - # Few things to note before we can do this optimization is - # the following issue resolved: - # https://github.com/rapidsai/cudf/issues/5286 - if preserve_order: - ind = as_column(cupy.arange(0, len(self))) - - # dedup based on the column of data only - ind, col = drop_duplicates([ind, self], keys=[1]) - - # sort col based on ind - map = ind.argsort() - return col.take(map) - + if preserve_order is not True: + warnings.warn( + "The preserve_order argument is deprecated. It will be " + "removed in a future version. As of now, unique always " + "preserves order regardless of the argument's value.", + FutureWarning, + ) return drop_duplicates([self], keep="first")[0] def serialize(self) -> Tuple[dict, list]: diff --git a/python/cudf/cudf/core/dataframe.py b/python/cudf/cudf/core/dataframe.py index eb6685861d4..5fc4870105b 100644 --- a/python/cudf/cudf/core/dataframe.py +++ b/python/cudf/cudf/core/dataframe.py @@ -3165,34 +3165,46 @@ def diff(self, periods=1, axis=0): @_cudf_nvtx_annotate def drop_duplicates( - self, subset=None, keep="first", inplace=False, ignore_index=False + self, + subset=None, + keep="first", + inplace=False, + ignore_index=False, ): """ - Return DataFrame with duplicate rows removed, optionally only - considering certain subset of columns. + Return DataFrame with duplicate rows removed. + + Considering certain columns is optional. Indexes, including time + indexes are ignored. Parameters ---------- subset : column label or sequence of labels, optional Only consider certain columns for identifying duplicates, by default use all of the columns. - keep : {'first', 'last', False}, default 'first' + keep : {'first', 'last', ``False``}, default 'first' Determines which duplicates (if any) to keep. - - ``first`` : Drop duplicates except for the first occurrence. - - ``last`` : Drop duplicates except for the last occurrence. - - False : Drop all duplicates. - inplace : bool, default False + - 'first' : Drop duplicates except for the first occurrence. + - 'last' : Drop duplicates except for the last occurrence. + - ``False`` : Drop all duplicates. + inplace : bool, default ``False`` Whether to drop duplicates in place or to return a copy. - ignore_index : bool, default False - If True, the resulting axis will be labeled 0, 1, …, n - 1. + ignore_index : bool, default ``False`` + If True, the resulting axis will be labeled 0, 1, ..., n - 1. Returns ------- DataFrame or None DataFrame with duplicates removed or None if ``inplace=True``. + See Also + -------- + DataFrame.value_counts: Count unique combinations of columns. + Examples -------- + Consider a dataset containing ramen ratings. + >>> import cudf >>> df = cudf.DataFrame({ ... 'brand': ['Yum Yum', 'Yum Yum', 'Indomie', 'Indomie', 'Indomie'], @@ -3207,36 +3219,34 @@ def drop_duplicates( 3 Indomie pack 15.0 4 Indomie pack 5.0 - By default, it removes duplicate rows based - on all columns. Note that order of - the rows being returned is not guaranteed - to be sorted. + By default, it removes duplicate rows based on all columns. >>> df.drop_duplicates() brand style rating + 0 Yum Yum cup 4.0 2 Indomie cup 3.5 - 4 Indomie pack 5.0 3 Indomie pack 15.0 - 0 Yum Yum cup 4.0 + 4 Indomie pack 5.0 - To remove duplicates on specific column(s), - use `subset`. + To remove duplicates on specific column(s), use ``subset``. >>> df.drop_duplicates(subset=['brand']) brand style rating - 2 Indomie cup 3.5 0 Yum Yum cup 4.0 + 2 Indomie cup 3.5 - To remove duplicates and keep last occurrences, use `keep`. + To remove duplicates and keep last occurrences, use ``keep``. >>> df.drop_duplicates(subset=['brand', 'style'], keep='last') brand style rating + 1 Yum Yum cup 4.0 2 Indomie cup 3.5 4 Indomie pack 5.0 - 1 Yum Yum cup 4.0 """ # noqa: E501 outdf = super().drop_duplicates( - subset=subset, keep=keep, ignore_index=ignore_index + subset=subset, + keep=keep, + ignore_index=ignore_index, ) return self._mimic_inplace(outdf, inplace=inplace) @@ -7693,7 +7703,7 @@ def _find_common_dtypes_and_categories(non_null_columns, dtypes): # Combine and de-dupe the categories categories[idx] = cudf.Series( concat_columns([col.categories for col in cols]) - )._column.unique(preserve_order=True) + )._column.unique() # Set the column dtype to the codes' dtype. The categories # will be re-assigned at the end dtypes[idx] = min_scalar_type(len(categories[idx])) diff --git a/python/cudf/cudf/core/groupby/groupby.py b/python/cudf/cudf/core/groupby/groupby.py index fe22cc48e0f..b7faed1dfc3 100644 --- a/python/cudf/cudf/core/groupby/groupby.py +++ b/python/cudf/cudf/core/groupby/groupby.py @@ -308,7 +308,7 @@ def dtypes(self): 2 object int64 3 object int64 """ - index = self.grouping.keys.unique().to_pandas() + index = self.grouping.keys.unique().sort_values().to_pandas() return pd.DataFrame( { name: [self.obj._dtypes[name]] * len(index) @@ -864,25 +864,27 @@ def ngroup(self, ascending=True): 5 0 dtype: int64 """ - num_groups = len(index := self.grouping.keys.unique()) + index = self.grouping.keys.unique().sort_values() + num_groups = len(index) _, has_null_group = bitmask_or([*index._columns]) if ascending: - if has_null_group: - group_ids = cudf.Series._from_data( - {None: cp.arange(-1, num_groups - 1)} - ) - else: - group_ids = cudf.Series._from_data( - {None: cp.arange(num_groups)} - ) + # Count ascending from 0 to num_groups - 1 + group_ids = cudf.Series._from_data({None: cp.arange(num_groups)}) + elif has_null_group: + # Count descending from num_groups - 1 to 0, but subtract one more + # for the null group making it num_groups - 2 to -1. + group_ids = cudf.Series._from_data( + {None: cp.arange(num_groups - 2, -2, -1)} + ) else: + # Count descending from num_groups - 1 to 0 group_ids = cudf.Series._from_data( {None: cp.arange(num_groups - 1, -1, -1)} ) if has_null_group: - group_ids.iloc[0] = cudf.NA + group_ids.iloc[-1] = cudf.NA group_ids._index = index return self._broadcast(group_ids) @@ -1065,7 +1067,7 @@ def _grouped(self): column_names=self.obj._column_names, index_names=self.obj._index_names, ) - group_names = grouped_keys.unique() + group_names = grouped_keys.unique().sort_values() return (group_names, offsets, grouped_keys, grouped_values) def _normalize_aggs( diff --git a/python/cudf/cudf/core/reshape.py b/python/cudf/cudf/core/reshape.py index 4b784ac7b20..2055ecc96a0 100644 --- a/python/cudf/cudf/core/reshape.py +++ b/python/cudf/cudf/core/reshape.py @@ -1138,7 +1138,7 @@ def _get_unique(column, dummy_na): if isinstance(column, cudf.core.column.CategoricalColumn): unique = column.categories else: - unique = column.unique() + unique = column.unique().sort_by_values()[0] if not dummy_na: if np.issubdtype(unique.dtype, np.floating): unique = unique.nans_to_nulls() diff --git a/python/cudf/cudf/core/series.py b/python/cudf/cudf/core/series.py index 0e8481dd820..a99eda6bd0b 100644 --- a/python/cudf/cudf/core/series.py +++ b/python/cudf/cudf/core/series.py @@ -1151,7 +1151,12 @@ def __array_function__(self, func, types, args, kwargs): try: # Apply a Series method if one exists. if cudf_func := getattr(Series, func.__name__, None): - return cudf_func(*args, **kwargs) + result = cudf_func(*args, **kwargs) + if func.__name__ == "unique": + # NumPy expects a sorted result for `unique`, which is not + # guaranteed by cudf.Series.unique. + result = result.sort_values() + return result # Assume that cupy subpackages match numpy and search the # corresponding cupy submodule based on the func's __module__. @@ -1718,20 +1723,20 @@ def drop_duplicates(self, keep="first", inplace=False, ignore_index=False): to be sorted. >>> s.drop_duplicates() - 3 beetle + 0 lama 1 cow + 3 beetle 5 hippo - 0 lama Name: animal, dtype: object The value 'last' for parameter `keep` keeps the last occurrence for each set of duplicated entries. >>> s.drop_duplicates(keep='last') - 3 beetle 1 cow - 5 hippo + 3 beetle 4 lama + 5 hippo Name: animal, dtype: object The value `False` for parameter `keep` discards all sets @@ -1740,8 +1745,8 @@ def drop_duplicates(self, keep="first", inplace=False, ignore_index=False): >>> s.drop_duplicates(keep=False, inplace=True) >>> s - 3 beetle 1 cow + 3 beetle 5 hippo Name: animal, dtype: object """ @@ -2887,9 +2892,9 @@ def unique(self): 6 c dtype: object >>> series.unique() - 0 - 1 a - 2 b + 0 a + 1 b + 2 3 c dtype: object """ diff --git a/python/cudf/cudf/tests/test_array_function.py b/python/cudf/cudf/tests/test_array_function.py index 65874c94b93..a355ebb40b2 100644 --- a/python/cudf/cudf/tests/test_array_function.py +++ b/python/cudf/cudf/tests/test_array_function.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018-2022, NVIDIA CORPORATION. +# Copyright (c) 2018-2023, NVIDIA CORPORATION. import numpy as np import pandas as pd import pytest @@ -94,15 +94,26 @@ def test_array_func_missing_cudf_dataframe(pd_df, func): func(cudf_df) -# we only implement sum among all numpy non-ufuncs @pytest.mark.skipif(missing_arrfunc_cond, reason=missing_arrfunc_reason) @pytest.mark.parametrize("np_ar", [np.random.random(100)]) -@pytest.mark.parametrize("func", [lambda x: np.sum(x), lambda x: np.dot(x, x)]) +@pytest.mark.parametrize( + "func", + [ + lambda x: np.mean(x), + lambda x: np.sum(x), + lambda x: np.var(x, ddof=1), + lambda x: np.unique(x), + lambda x: np.dot(x, x), + ], +) def test_array_func_cudf_index(np_ar, func): cudf_index = cudf.core.index.as_index(cudf.Series(np_ar)) expect = func(np_ar) got = func(cudf_index) - assert_eq(expect, got) + if np.isscalar(expect): + assert_eq(expect, got) + else: + assert_eq(expect, got.to_numpy()) @pytest.mark.skipif(missing_arrfunc_cond, reason=missing_arrfunc_reason) diff --git a/python/cudf/cudf/tests/test_onehot.py b/python/cudf/cudf/tests/test_onehot.py index e5ca2e028c3..6d5bfde7740 100644 --- a/python/cudf/cudf/tests/test_onehot.py +++ b/python/cudf/cudf/tests/test_onehot.py @@ -8,7 +8,7 @@ import cudf from cudf import DataFrame -from cudf.testing import _utils as utils +from cudf.testing._utils import assert_eq pytestmark = pytest.mark.spilling @@ -31,14 +31,14 @@ def test_get_dummies(data, index): with pytest.warns(FutureWarning): encoded_actual = cudf.get_dummies(gdf, prefix="test") - utils.assert_eq( + assert_eq( encoded_expected, encoded_actual, check_dtype=len(data) != 0, ) encoded_actual = cudf.get_dummies(gdf, prefix="test", dtype=np.uint8) - utils.assert_eq( + assert_eq( encoded_expected, encoded_actual, check_dtype=len(data) != 0, @@ -59,7 +59,7 @@ def test_onehot_get_dummies_multicol(n_cols): with pytest.warns(FutureWarning): encoded_actual = cudf.get_dummies(gdf, prefix="test") - utils.assert_eq(encoded_expected, encoded_actual) + assert_eq(encoded_expected, encoded_actual) @pytest.mark.parametrize("nan_as_null", [True, False]) @@ -75,7 +75,7 @@ def test_onehost_get_dummies_dummy_na(nan_as_null, dummy_na): if dummy_na and nan_as_null: got = got.rename(columns={"a_null": "a_nan"})[expected.columns] - utils.assert_eq(expected, got) + assert_eq(expected, got) @pytest.mark.parametrize( @@ -115,7 +115,7 @@ def test_get_dummies_prefix_sep(prefix, prefix_sep): gdf, prefix=prefix, prefix_sep=prefix_sep ) - utils.assert_eq(encoded_expected, encoded_actual) + assert_eq(encoded_expected, encoded_actual) def test_get_dummies_with_nan(): @@ -124,55 +124,55 @@ def test_get_dummies_with_nan(): ) expected = cudf.DataFrame( { - "a_null": [0, 0, 0, 1], "a_1.0": [1, 0, 0, 0], "a_2.0": [0, 1, 0, 0], "a_nan": [0, 0, 1, 0], + "a_null": [0, 0, 0, 1], }, dtype="uint8", ) with pytest.warns(FutureWarning): actual = cudf.get_dummies(df, dummy_na=True, columns=["a"]) - utils.assert_eq(expected, actual) + assert_eq(expected, actual) @pytest.mark.parametrize( "data", [ - cudf.Series(["abc", "l", "a", "abc", "z", "xyz"]), - cudf.Index([None, 1, 2, 3.3, None, 0.2]), - cudf.Series([0.1, 2, 3, None, np.nan]), - cudf.Series([23678, 324, 1, 324], name="abc"), + lambda: cudf.Series(["abc", "l", "a", "abc", "z", "xyz"]), + lambda: cudf.Index([None, 1, 2, 3.3, None, 0.2]), + lambda: cudf.Series([0.1, 2, 3, None, np.nan]), + lambda: cudf.Series([23678, 324, 1, 324], name="abc"), ], ) @pytest.mark.parametrize("prefix_sep", ["-", "#"]) @pytest.mark.parametrize("prefix", [None, "hi"]) @pytest.mark.parametrize("dtype", ["uint8", "int16"]) def test_get_dummies_array_like(data, prefix_sep, prefix, dtype): - actual = cudf.get_dummies( - data, prefix=prefix, prefix_sep=prefix_sep, dtype=dtype - ) - if isinstance(data, (cudf.Series, cudf.BaseIndex)): - pd_data = data.to_pandas() - else: - pd_data = data + data = data() + pd_data = data.to_pandas() expected = pd.get_dummies( pd_data, prefix=prefix, prefix_sep=prefix_sep, dtype=dtype ) - utils.assert_eq(expected, actual) + + actual = cudf.get_dummies( + data, prefix=prefix, prefix_sep=prefix_sep, dtype=dtype + ) + + assert_eq(expected, actual) def test_get_dummies_array_like_with_nan(): ser = cudf.Series([0.1, 2, 3, None, np.nan], nan_as_null=False) expected = cudf.DataFrame( { - "a_null": [0, 0, 0, 1, 0], "a_0.1": [1, 0, 0, 0, 0], "a_2.0": [0, 1, 0, 0, 0], "a_3.0": [0, 0, 1, 0, 0], "a_nan": [0, 0, 0, 0, 1], + "a_null": [0, 0, 0, 1, 0], }, dtype="uint8", ) @@ -181,4 +181,4 @@ def test_get_dummies_array_like_with_nan(): ser, dummy_na=True, prefix="a", prefix_sep="_" ) - utils.assert_eq(expected, actual) + assert_eq(expected, actual) diff --git a/python/cudf/cudf/tests/test_string.py b/python/cudf/cudf/tests/test_string.py index 816eb6468b0..2bddd93ccb8 100644 --- a/python/cudf/cudf/tests/test_string.py +++ b/python/cudf/cudf/tests/test_string.py @@ -1103,8 +1103,7 @@ def test_string_unique(item): gs = cudf.Series(item) # Pandas `unique` returns a numpy array pres = pd.Series(ps.unique()) - # cudf returns sorted unique with `None` placed before other strings - pres = pres.sort_values(na_position="first").reset_index(drop=True) + # cudf returns a cudf.Series gres = gs.unique() assert_eq(pres, gres)