diff --git a/python/cudf/cudf/core/index.py b/python/cudf/cudf/core/index.py
index 094da09ab08..7f40428c1b8 100644
--- a/python/cudf/cudf/core/index.py
+++ b/python/cudf/cudf/core/index.py
@@ -5,6 +5,7 @@
 import operator
 import pickle
 import warnings
+from collections.abc import Hashable
 from functools import cache, cached_property
 from numbers import Number
 from typing import TYPE_CHECKING, Any, Literal, MutableMapping, cast
@@ -60,7 +61,7 @@
 from cudf.utils.utils import _warn_no_dask_cudf, search_range
 
 if TYPE_CHECKING:
-    from collections.abc import Generator, Hashable, Iterable
+    from collections.abc import Generator, Iterable
     from datetime import tzinfo
 
 
@@ -450,6 +451,26 @@ def __getitem__(self, index):
             return self.start + index * self.step
         return self._as_int_index()[index]
 
+    def _get_columns_by_label(self, labels) -> Index:
+        """
+        Select this index by label; used in ``.sort_values``.
+
+        Returns ``self._as_int_index()`` when ``labels`` is a hashable
+        equal to ``self.name``, or a list-like equal to ``self.names``.
+
+        Raises
+        ------
+        KeyError
+            If ``labels`` does not match this index.
+        """
+        if isinstance(labels, Hashable):
+            if labels == self.name:
+                return self._as_int_index()
+        elif is_list_like(labels):
+            if list(self.names) == list(labels):
+                return self._as_int_index()
+        raise KeyError(labels)
+
     @_performance_tracking
     def equals(self, other) -> bool:
         if isinstance(other, RangeIndex):
diff --git a/python/cudf/cudf/core/indexed_frame.py b/python/cudf/cudf/core/indexed_frame.py
index 24d947a574a..3b44a0f5864 100644
--- a/python/cudf/cudf/core/indexed_frame.py
+++ b/python/cudf/cudf/core/indexed_frame.py
@@ -3592,10 +3592,36 @@ def sort_values(
         if len(self) == 0:
             return self
 
+        # Resolve `by` against the column labels and, for non-Series, the
+        # index levels; exactly one of the two lookups must succeed.
+        try:
+            by_in_columns = self._get_columns_by_label(by)
+        except KeyError:
+            by_in_columns = None
+        if self.ndim == 1:
+            # For Series case, we're never selecting an index level.
+            by_in_index = None
+        else:
+            try:
+                by_in_index = self.index._get_columns_by_label(by)
+            except KeyError:
+                by_in_index = None
+
+        if by_in_columns is not None and by_in_index is not None:
+            raise ValueError(
+                f"{by=} appears in the {type(self).__name__} columns "
+                "and as an index level which is ambiguous."
+            )
+        elif by_in_columns is not None:
+            by_columns = by_in_columns
+        elif by_in_index is not None:
+            by_columns = by_in_index
+        else:
+            raise KeyError(by)
         # argsort the `by` column
         out = self._gather(
             GatherMap.from_column_unchecked(
-                self._get_columns_by_label(by)._get_sorted_inds(
+                by_columns._get_sorted_inds(
                     ascending=ascending, na_position=na_position
                 ),
                 len(self),
diff --git a/python/cudf/cudf/tests/test_sorting.py b/python/cudf/cudf/tests/test_sorting.py
index a8ffce6e88b..2cf2259d9ec 100644
--- a/python/cudf/cudf/tests/test_sorting.py
+++ b/python/cudf/cudf/tests/test_sorting.py
@@ -405,3 +405,25 @@ def test_dataframe_scatter_by_map_empty():
     df = DataFrame({"a": [], "b": []}, dtype="float64")
     scattered = df.scatter_by_map(df["a"])
     assert len(scattered) == 0
+
+
+def test_sort_values_by_index_level():
+    """Sort by an index level name and compare with the pandas result."""
+    df = pd.DataFrame({"a": [1, 3, 2]}, index=pd.Index([1, 3, 2], name="b"))
+    cudf_df = DataFrame.from_pandas(df)
+    result = cudf_df.sort_values("b")
+    expected = df.sort_values("b")
+    assert_eq(result, expected)
+
+
+def test_sort_values_by_ambiguous():
+    """Column/index-level name clash: compare pandas and cudf exceptions."""
+    df = pd.DataFrame({"a": [1, 3, 2]}, index=pd.Index([1, 3, 2], name="a"))
+    cudf_df = DataFrame.from_pandas(df)
+
+    assert_exceptions_equal(
+        lfunc=df.sort_values,
+        rfunc=cudf_df.sort_values,
+        lfunc_args_and_kwargs=(["a"], {}),
+        rfunc_args_and_kwargs=(["a"], {}),
+    )