Allow custom sort functions for dask-cudf sort_values #9789

Merged (7 commits, Jan 14, 2022)
python/dask_cudf/dask_cudf/core.py (15 additions, 3 deletions)

@@ -235,17 +235,27 @@ def sort_values(
         set_divisions=False,
         ascending=True,
         na_position="last",
+        sort_function=None,
+        sort_function_kwargs=None,
         **kwargs,
     ):
         if kwargs:
             raise ValueError(
                 f"Unsupported input arguments passed : {list(kwargs.keys())}"
             )

+        sort_kwargs = {
+            "by": by,
+            "ascending": ascending,
+            "na_position": na_position,
+        }
+        if sort_function is None:
+            sort_function = M.sort_values
+        if sort_function_kwargs is not None:
+            sort_kwargs.update(sort_function_kwargs)
+
         if self.npartitions == 1:
-            df = self.map_partitions(
-                M.sort_values, by, ascending=ascending, na_position=na_position
-            )
+            df = self.map_partitions(sort_function, **sort_kwargs)
         else:
             df = sorting.sort_values(
                 self,
@@ -256,6 +266,8 @@
                 ignore_index=ignore_index,
                 ascending=ascending,
                 na_position=na_position,
+                sort_function=sort_function,
+                sort_function_kwargs=sort_kwargs,
             )

         if ignore_index:
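A user-supplied sort_function is invoked once per partition with the merged keyword arguments built above. A minimal sketch of what such a function can look like, using pandas as a stand-in for cudf (the name `stable_sort` and the mergesort choice are illustrative examples, not part of the PR):

```python
import pandas as pd

def stable_sort(partition, by, ascending=True, na_position="last"):
    # Custom per-partition behavior: force a stable mergesort instead of
    # the default quicksort.
    return partition.sort_values(
        by, ascending=ascending, na_position=na_position, kind="mergesort"
    )

# Each partition would be passed as the first argument:
part = pd.DataFrame({"a": [3, 1, 2], "b": [1, 2, 3]})
print(stable_sort(part, by="a")["a"].tolist())  # [1, 2, 3]
```

With the PR applied, this callable would be passed as `sort_function=stable_sort` to `DataFrame.sort_values`, with any extra arguments supplied via `sort_function_kwargs`.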
python/dask_cudf/dask_cudf/sorting.py (3 additions, 4 deletions)

@@ -10,7 +10,6 @@
 from dask.dataframe.core import DataFrame, Index, Series
 from dask.dataframe.shuffle import rearrange_by_column
 from dask.highlevelgraph import HighLevelGraph
-from dask.utils import M

 import cudf as gd
 from cudf.api.types import is_categorical_dtype
@@ -222,6 +221,8 @@ def sort_values(
     ignore_index=False,
     ascending=True,
     na_position="last",
+    sort_function=None,
+    sort_function_kwargs=None,
 ):
     """Sort by the given list/tuple of column names."""
     if na_position not in ("first", "last"):
@@ -263,9 +264,7 @@
     df3.divisions = (None,) * (df3.npartitions + 1)

     # Step 3 - Return final sorted df
-    df4 = df3.map_partitions(
-        M.sort_values, by, ascending=ascending, na_position=na_position
-    )
+    df4 = df3.map_partitions(sort_function, **sort_function_kwargs)
Member commented:

Something feels off here. We are requiring that the user specify sort_function, but the API makes it seem optional. I worry that we are now silently ignoring ascending and na_position (and maybe even by?).

What if downstream users are calling sorting.sort_values directly? I don't think that is good/recommended practice, but the API we are changing seems "public" to me (making this a breaking change).

Perhaps a simpler (non-breaking) solution would be to remove most of the changes from DataFrame.sort_values, pass sort_function and sort_function_kwargs through to here, and implement the sort_function/sort_function_kwargs default logic here (in sorting.sort_values). Does this seem reasonable?
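The reviewer's suggestion amounts to resolving the defaults in one place, so ascending and na_position are always honored when no custom function is given. A framework-free sketch of that default logic (the helper name `resolve_sort_call` is hypothetical, and a plain callable stands in for `dask.utils.M.sort_values` so the sketch runs without dask):

```python
def resolve_sort_call(by, ascending=True, na_position="last",
                      sort_function=None, sort_function_kwargs=None):
    # Resolve the per-partition sort callable and its kwargs in one place,
    # so ascending/na_position are honored whenever sort_function is omitted.
    if sort_function is None:
        # Stand-in for dask.utils.M.sort_values (a method caller).
        sort_function = lambda df, **kw: df.sort_values(**kw)
    sort_kwargs = {"by": by, "ascending": ascending, "na_position": na_position}
    if sort_function_kwargs is not None:
        # User-provided kwargs take precedence over the defaults.
        sort_kwargs.update(sort_function_kwargs)
    return sort_function, sort_kwargs
```

Because the defaults live next to the shuffle logic, direct callers of sorting.sort_values keep their old behavior, which is what makes the change non-breaking.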

Member Author replied:

That makes sense and is a valid concern. My only comment is that we ideally still want to allow custom sort functions in the npartitions == 1 case, which is handled directly in DataFrame.sort_values, so I think it might also make sense to move the following logic:

        if self.npartitions == 1:
            df = self.map_partitions(sort_function, **sort_kwargs)

into sorting.sort_values as well, unless there is a reason, not immediately obvious to me, why we would want to keep the single-partition case separate.

Also noting that this is a concern for the upstream implementation as well, so depending on what we decide here I will open a follow-up PR to address it in Dask.

Member replied:

> Also noting that this is a concern for the upstream implementation as well, so depending on what we decide here I will open a follow-up PR to address it in Dask.

Good point! I definitely like the simplification you made here, so it probably makes sense to do something similar upstream.

     if not isinstance(divisions, gd.DataFrame) and set_divisions:
         # Can't have multi-column divisions elsewhere in dask (yet)
         df4.divisions = methods.tolist(divisions)
python/dask_cudf/dask_cudf/tests/test_sort.py (19 additions, 0 deletions)

@@ -83,3 +83,22 @@ def test_sort_values_with_nulls(data, by, ascending, na_position):

     # cudf ordering for nulls is non-deterministic
     dd.assert_eq(got[by], expect[by], check_index=False)
+
+
+@pytest.mark.parametrize("by", [["a", "b"], ["b", "a"]])
+@pytest.mark.parametrize("nparts", [1, 10])
+def test_sort_values_custom_function(by, nparts):
+    df = cudf.DataFrame({"a": [1, 2, 3] * 20, "b": [4, 5, 6, 7] * 15})
+    ddf = dd.from_pandas(df, npartitions=nparts)
+
+    def f(partition, by_columns, ascending, na_position, **kwargs):
+        return partition.sort_values(
+            by_columns, ascending=ascending, na_position=na_position
+        )
+
+    with dask.config.set(scheduler="single-threaded"):
+        got = ddf.sort_values(
+            by=by[0], sort_function=f, sort_function_kwargs={"by_columns": by}
+        )
+        expect = df.sort_values(by=by)
+        dd.assert_eq(got, expect, check_index=False)
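One subtlety in the test above: the outer by=by[0] is merged into the same kwargs dict as by_columns, so f receives both and swallows the unused by via **kwargs, while by_columns (from sort_function_kwargs) drives the actual sort. A pandas stand-in makes that flow explicit (merged_kwargs mimics what would be forwarded to each partition; illustrative only):

```python
import pandas as pd

def f(partition, by_columns, ascending, na_position, **kwargs):
    # "by" from the outer sort_values call lands in **kwargs and is ignored;
    # "by_columns" from sort_function_kwargs selects the sort columns.
    return partition.sort_values(
        by_columns, ascending=ascending, na_position=na_position
    )

part = pd.DataFrame({"a": [2, 1, 1], "b": [5, 7, 4]})
merged_kwargs = {"by": "a", "ascending": True, "na_position": "last",
                 "by_columns": ["a", "b"]}
out = f(part, **merged_kwargs)
print(out["b"].tolist())  # [4, 7, 5]
```

This is why the custom function must accept **kwargs (or an explicit by parameter): the framework always forwards the outer by alongside any user-supplied kwargs.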