Skip to content

Commit

Permalink
Allow DataFrame.sort_values(by=) to select an index level (#16519)
Browse files Browse the repository at this point in the history
closes #14794

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)

Approvers:
  - Matthew Murray (https://github.com/Matt711)

URL: #16519
  • Loading branch information
mroeschke authored Aug 9, 2024
1 parent 4446cf0 commit 16aa0ea
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 2 deletions.
13 changes: 12 additions & 1 deletion python/cudf/cudf/core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import operator
import pickle
import warnings
from collections.abc import Hashable
from functools import cache, cached_property
from numbers import Number
from typing import TYPE_CHECKING, Any, Literal, MutableMapping, cast
Expand Down Expand Up @@ -60,7 +61,7 @@
from cudf.utils.utils import _warn_no_dask_cudf, search_range

if TYPE_CHECKING:
from collections.abc import Generator, Hashable, Iterable
from collections.abc import Generator, Iterable
from datetime import tzinfo


Expand Down Expand Up @@ -450,6 +451,16 @@ def __getitem__(self, index):
return self.start + index * self.step
return self._as_int_index()[index]

def _get_columns_by_label(self, labels) -> Index:
# used in .sort_values
if isinstance(labels, Hashable):
if labels == self.name:
return self._as_int_index()
elif is_list_like(labels):
if list(self.names) == list(labels):
return self._as_int_index()
raise KeyError(labels)

@_performance_tracking
def equals(self, other) -> bool:
if isinstance(other, RangeIndex):
Expand Down
26 changes: 25 additions & 1 deletion python/cudf/cudf/core/indexed_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3592,10 +3592,34 @@ def sort_values(
if len(self) == 0:
return self

try:
by_in_columns = self._get_columns_by_label(by)
except KeyError:
by_in_columns = None
if self.ndim == 1:
# For Series case, we're never selecting an index level.
by_in_index = None
else:
try:
by_in_index = self.index._get_columns_by_label(by)
except KeyError:
by_in_index = None

if by_in_columns is not None and by_in_index is not None:
raise ValueError(
f"{by=} appears in the {type(self).__name__} columns "
"and as an index level which is ambiguous."
)
elif by_in_columns is not None:
by_columns = by_in_columns
elif by_in_index is not None:
by_columns = by_in_index
else:
raise KeyError(by)
# argsort the `by` column
out = self._gather(
GatherMap.from_column_unchecked(
self._get_columns_by_label(by)._get_sorted_inds(
by_columns._get_sorted_inds(
ascending=ascending, na_position=na_position
),
len(self),
Expand Down
20 changes: 20 additions & 0 deletions python/cudf/cudf/tests/test_sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,3 +405,23 @@ def test_dataframe_scatter_by_map_empty():
df = DataFrame({"a": [], "b": []}, dtype="float64")
scattered = df.scatter_by_map(df["a"])
assert len(scattered) == 0


def test_sort_values_by_index_level():
    """Sort a DataFrame by the name of its index and compare with pandas."""
    # Column is named "a" and the index is named "b", so "b" can only
    # refer to the index level.
    level_name = "b"
    pdf = pd.DataFrame(
        {"a": [1, 3, 2]},
        index=pd.Index([1, 3, 2], name=level_name),
    )
    gdf = DataFrame.from_pandas(pdf)
    assert_eq(gdf.sort_values(level_name), pdf.sort_values(level_name))


def test_sort_values_by_ambiguous():
    """Call ``sort_values("a")`` when "a" names both a column and the index.

    The pandas and cudf calls are handed to ``assert_exceptions_equal``
    with identical arguments.
    """
    shared_name = "a"
    pdf = pd.DataFrame(
        {shared_name: [1, 3, 2]},
        index=pd.Index([1, 3, 2], name=shared_name),
    )
    gdf = DataFrame.from_pandas(pdf)

    assert_exceptions_equal(
        lfunc=pdf.sort_values,
        rfunc=gdf.sort_values,
        lfunc_args_and_kwargs=([shared_name], {}),
        rfunc_args_and_kwargs=([shared_name], {}),
    )

0 comments on commit 16aa0ea

Please sign in to comment.