From a82dc4b5d7b7b6885b17f71436b209e5e4bb4437 Mon Sep 17 00:00:00 2001 From: galipremsagar Date: Wed, 24 Mar 2021 09:55:28 -0700 Subject: [PATCH] fix dataframe argsort return type --- python/cudf/cudf/core/dataframe.py | 24 +++++++++++++++++++++++- python/cudf/cudf/tests/test_dataframe.py | 21 +++++++++++++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/python/cudf/cudf/core/dataframe.py b/python/cudf/cudf/core/dataframe.py index bd009a9ad84..b5f57356698 100644 --- a/python/cudf/cudf/core/dataframe.py +++ b/python/cudf/cudf/core/dataframe.py @@ -3841,10 +3841,32 @@ def argsort(self, ascending=True, na_position="last"): - Support axis='index' only. - Not supporting: inplace, kind - Ascending can be a list of bools to control per column + + Examples + -------- + >>> import cudf + >>> df = cudf.DataFrame({'a':[10, 0, 2], 'b':[-10, 10, 1]}) + >>> df + a b + 0 10 -10 + 1 0 10 + 2 2 1 + >>> inds = df.argsort() + >>> inds + 0 1 + 1 2 + 2 0 + dtype: int32 + >>> df.take(inds) + a b + 1 0 10 + 2 2 1 + 0 10 -10 """ - return self._get_sorted_inds( + inds_col = self._get_sorted_inds( ascending=ascending, na_position=na_position ) + return cudf.Series(inds_col) @annotate("SORT_INDEX", color="red", domain="cudf_python") def sort_index( diff --git a/python/cudf/cudf/tests/test_dataframe.py b/python/cudf/cudf/tests/test_dataframe.py index 76a02d5e74a..d72b88f1713 100644 --- a/python/cudf/cudf/tests/test_dataframe.py +++ b/python/cudf/cudf/tests/test_dataframe.py @@ -8495,3 +8495,24 @@ def test_explode(data, labels, ignore_index, p_index, label_to_explode): got = gdf.explode(label_to_explode, ignore_index) assert_eq(expect, got, check_dtype=False) + + +@pytest.mark.parametrize( + "df,ascending,expected", + [ + ( + cudf.DataFrame({"a": [10, 0, 2], "b": [-10, 10, 1]}), + True, + cudf.Series([1, 2, 0], dtype="int32"), + ), + ( + cudf.DataFrame({"a": [10, 0, 2], "b": [-10, 10, 1]}), + False, + cudf.Series([0, 2, 1], dtype="int32"), + ), + ], +) +def test_dataframe_argsort(df, ascending, expected): + actual = df.argsort(ascending=ascending) + + assert_eq(actual, expected)