Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose sorted groupby parameters to pylibcudf #16240

Merged
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/cudf/cudf/_lib/pylibcudf/groupby.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ from cudf._lib.pylibcudf.libcudf.groupby cimport (
scan_request,
)
from cudf._lib.pylibcudf.libcudf.table.table cimport table
from cudf._lib.pylibcudf.libcudf.types cimport null_order, order

from .column cimport Column
from .table cimport Table
Expand All @@ -38,6 +39,9 @@ cdef class GroupByRequest:
cdef class GroupBy:
cdef unique_ptr[groupby] c_obj
cdef Table _keys
cdef unique_ptr[vector[order]] _column_order
cdef unique_ptr[vector[null_order]] _null_precedence

cpdef tuple aggregate(self, list requests)
cpdef tuple scan(self, list requests)
cpdef tuple shift(self, Table values, list offset, list fill_values)
Expand Down
33 changes: 29 additions & 4 deletions python/cudf/cudf/_lib/pylibcudf/groupby.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from cython.operator cimport dereference
from libcpp.functional cimport reference_wrapper
from libcpp.memory cimport unique_ptr
from libcpp.memory cimport make_unique, unique_ptr
from libcpp.pair cimport pair
from libcpp.utility cimport move
from libcpp.vector cimport vector
Expand All @@ -22,7 +22,7 @@ from cudf._lib.pylibcudf.libcudf.types cimport size_type
from .aggregation cimport Aggregation
from .column cimport Column
from .table cimport Table
from .types cimport null_policy, sorted
from .types cimport null_order, null_policy, order, sorted
from .utils cimport _as_vector


Expand Down Expand Up @@ -90,14 +90,39 @@ cdef class GroupBy:
Whether or not to include null rows in ``keys``. Default is null_policy.EXCLUDE.
keys_are_sorted : sorted, optional
Whether the keys are already sorted. Default is sorted.NO.
column_order : list[order]
Order of each key column if the keys are sorted. Default if not
provided uses order.ASCENDING. Ignored if keys_are_sorted is sorted.NO.
wence- marked this conversation as resolved.
Show resolved Hide resolved
null_precedence : list[null_order]
Where do nulls sort if the keys are sorted?
wence- marked this conversation as resolved.
Show resolved Hide resolved
Default is null_order.AFTER. Ignored if keys_are_sorted is sorted.NO.
"""
def __init__(
self,
Table keys,
null_policy null_handling=null_policy.EXCLUDE,
sorted keys_are_sorted=sorted.NO
sorted keys_are_sorted=sorted.NO,
list column_order=None,
list null_precedence=None,
):
self.c_obj.reset(new groupby(keys.view(), null_handling, keys_are_sorted))
self._column_order = make_unique[vector[order]]()
self._null_precedence = make_unique[vector[null_order]]()
if column_order is not None:
for o in column_order:
dereference(self._column_order).push_back(<order?>o)
if null_precedence is not None:
for o in null_precedence:
dereference(self._null_precedence).push_back(<null_order?>o)

self.c_obj.reset(
new groupby(
keys.view(),
null_handling,
keys_are_sorted,
dereference(self._column_order.get()),
dereference(self._null_precedence.get()),
)
)
# keep a reference to the keys table so it doesn't get
# deallocated from under us:
self._keys = keys
Expand Down
39 changes: 38 additions & 1 deletion python/cudf_polars/cudf_polars/containers/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
if TYPE_CHECKING:
from typing_extensions import Self

import polars as pl

__all__: list[str] = ["Column", "NamedColumn"]


Expand Down Expand Up @@ -76,12 +78,47 @@ def sorted_like(self, like: Column, /) -> Self:

See Also
--------
set_sorted
set_sorted, copy_metadata
"""
return self.set_sorted(
is_sorted=like.is_sorted, order=like.order, null_order=like.null_order
)

def copy_metadata(self, from_: pl.Series, /) -> Self:
"""
Copy metadata from a host series onto self.

Parameters
----------
from_
Polars series to copy metadata from

Returns
-------
Self with metadata set.

See Also
--------
set_sorted, sorted_like
"""
ascending = from_.flags["SORTED_ASC"]
descending = from_.flags["SORTED_DESC"]
if ascending or descending:
has_null_first = from_.item(0) is None
has_null_last = from_.item(-1) is None
wence- marked this conversation as resolved.
Show resolved Hide resolved
order = (
plc.types.Order.ASCENDING if ascending else plc.types.Order.DESCENDING
)
null_order = plc.types.NullOrder.BEFORE
if (descending and has_null_first) or (ascending and has_null_last):
null_order = plc.types.NullOrder.AFTER
return self.set_sorted(
is_sorted=plc.types.Sorted.YES,
order=order,
null_order=null_order,
)
return self

def set_sorted(
self,
*,
Expand Down
45 changes: 42 additions & 3 deletions python/cudf_polars/cudf_polars/containers/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,18 @@
from functools import cached_property
from typing import TYPE_CHECKING, cast

import pyarrow as pa

import polars as pl

import cudf._lib.pylibcudf as plc

from cudf_polars.containers.column import NamedColumn
from cudf_polars.utils import dtypes

if TYPE_CHECKING:
from collections.abc import Mapping, Sequence, Set

import pyarrow as pa
from typing_extensions import Self

import cudf
Expand Down Expand Up @@ -50,8 +52,16 @@ def to_polars(self) -> pl.DataFrame:
self.table,
[plc.interop.ColumnMetadata(name=c.name) for c in self.columns],
)

return cast(pl.DataFrame, pl.from_arrow(table))
return cast(pl.DataFrame, pl.from_arrow(table)).with_columns(
*(
pl.col(c.name).set_sorted(
descending=c.order == plc.types.Order.DESCENDING
)
if c.is_sorted
else pl.col(c.name)
for c in self.columns
)
)

@cached_property
def column_names_set(self) -> frozenset[str]:
Expand Down Expand Up @@ -83,6 +93,35 @@ def from_cudf(cls, df: cudf.DataFrame) -> Self:
]
)

@classmethod
def from_polars(cls, df: pl.DataFrame) -> Self:
"""
Create from a polars dataframe.

Parameters
----------
df
Polars dataframe to convert

Returns
-------
New dataframe representing the input.
"""
table = df.to_arrow()
wence- marked this conversation as resolved.
Show resolved Hide resolved
schema = table.schema
for i, field in enumerate(schema):
schema = schema.set(
i, pa.field(field.name, dtypes.downcast_arrow_lists(field.type))
)
# No-op if the schema is unchanged.
d_table = plc.interop.from_arrow(table.cast(schema))
return cls(
[
NamedColumn(column, h_col.name).copy_metadata(h_col)
wence- marked this conversation as resolved.
Show resolved Hide resolved
for column, h_col in zip(d_table.columns(), df.iter_columns())
]
)

@classmethod
def from_table(cls, table: plc.Table, names: Sequence[str]) -> Self:
"""
Expand Down
29 changes: 28 additions & 1 deletion python/cudf_polars/cudf_polars/dsl/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,7 +867,7 @@ def __init__(
self.name = name
self.options = options
self.children = children
if self.name not in ("round", "unique"):
if self.name not in ("round", "unique", "setsorted"):
raise NotImplementedError(f"Unary function {name=}")

def do_evaluate(
Expand Down Expand Up @@ -923,6 +923,33 @@ def do_evaluate(
if maintain_order:
return Column(column).sorted_like(values)
return Column(column)
elif self.name == "setsorted":
(column,) = (
child.evaluate(df, context=context, mapping=mapping)
for child in self.children
)
(asc,) = self.options
order = (
plc.types.Order.ASCENDING
if asc == "ascending"
else plc.types.Order.DESCENDING
)
null_order = plc.types.NullOrder.BEFORE
if column.obj.null_count() > 0 and (n := column.obj.size()) > 1:
# PERF: This invokes four stream synchronisations!
wence- marked this conversation as resolved.
Show resolved Hide resolved
has_nulls_first = not plc.copying.get_element(column.obj, 0).is_valid()
has_nulls_last = not plc.copying.get_element(
column.obj, n - 1
).is_valid()
if (order == plc.types.Order.DESCENDING and has_nulls_first) or (
order == plc.types.Order.ASCENDING and has_nulls_last
):
null_order = plc.types.NullOrder.AFTER
return column.set_sorted(
is_sorted=plc.types.Sorted.YES,
order=order,
null_order=null_order,
)
raise NotImplementedError(
f"Unimplemented unary function {self.name=}"
) # pragma: no cover; init trips first
Expand Down
29 changes: 10 additions & 19 deletions python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

import cudf_polars.dsl.expr as expr
from cudf_polars.containers import DataFrame, NamedColumn
from cudf_polars.utils import dtypes, sorting
from cudf_polars.utils import sorting

if TYPE_CHECKING:
from collections.abc import MutableMapping
Expand Down Expand Up @@ -385,17 +385,7 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
pdf = pl.DataFrame._from_pydf(self.df)
if self.projection is not None:
pdf = pdf.select(self.projection)
table = pdf.to_arrow()
schema = table.schema
for i, field in enumerate(schema):
schema = schema.set(
i, pa.field(field.name, dtypes.downcast_arrow_lists(field.type))
)
# No-op if the schema is unchanged.
table = table.cast(schema)
df = DataFrame.from_table(
plc.interop.from_arrow(table), list(self.schema.keys())
)
df = DataFrame.from_polars(pdf)
assert all(
c.obj.type() == dtype for c, dtype in zip(df.columns, self.schema.values())
)
Expand Down Expand Up @@ -542,16 +532,17 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
keys = broadcast(
*(k.evaluate(df) for k in self.keys), target_length=df.num_rows
)
# TODO: use sorted information, need to expose column_order
# and null_precedence in pylibcudf groupby constructor
# sorted = (
# plc.types.Sorted.YES
# if all(k.is_sorted for k in keys)
# else plc.types.Sorted.NO
# )
sorted = (
plc.types.Sorted.YES
if all(k.is_sorted for k in keys)
else plc.types.Sorted.NO
)
grouper = plc.groupby.GroupBy(
plc.Table([k.obj for k in keys]),
null_handling=plc.types.NullPolicy.INCLUDE,
keys_are_sorted=sorted,
column_order=[k.order for k in keys],
null_precedence=[k.null_order for k in keys],
)
# TODO: uniquify
requests = []
Expand Down
38 changes: 38 additions & 0 deletions python/cudf_polars/tests/containers/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

import pytest

import polars as pl

import cudf._lib.pylibcudf as plc

from cudf_polars.containers import DataFrame, NamedColumn
Expand Down Expand Up @@ -90,3 +92,39 @@ def test_shallow_copy():
)
assert df.columns[0].is_sorted == plc.types.Sorted.YES
assert copy.columns[0].is_sorted == plc.types.Sorted.NO


@pytest.mark.parametrize("nulls_last", [True, False])
def test_flags_preserved(with_nulls, nulls_last):
wence- marked this conversation as resolved.
Show resolved Hide resolved
values = [1, 2, -1, 2, 4, 5]
if with_nulls:
values[4] = None
df = pl.DataFrame({"a": values, "b": values, "c": values})

df = df.select(
pl.col("a").sort(descending=False, nulls_last=nulls_last),
pl.col("b").sort(descending=True, nulls_last=nulls_last),
pl.col("c"),
)

gf = DataFrame.from_polars(df)

a_null_order = (
plc.types.NullOrder.AFTER
if nulls_last and with_nulls
else plc.types.NullOrder.BEFORE
)
b_null_order = (
plc.types.NullOrder.AFTER
if not nulls_last and with_nulls
else plc.types.NullOrder.BEFORE
)
a, b, c = gf.columns
assert a.is_sorted == plc.types.Sorted.YES
assert a.order == plc.types.Order.ASCENDING
assert a.null_order == a_null_order
assert b.is_sorted == plc.types.Sorted.YES
assert b.order == plc.types.Order.DESCENDING
assert b.null_order == b_null_order
assert c.is_sorted == plc.types.Sorted.NO
assert df.flags == gf.to_polars().flags
8 changes: 1 addition & 7 deletions python/cudf_polars/tests/expressions/test_agg.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,7 @@ def dtype(request):
return request.param


@pytest.fixture(
params=[
False,
pytest.param(True, marks=pytest.mark.xfail(reason="No handler for set_sorted")),
],
ids=["unsorted", "sorted"],
)
@pytest.fixture(params=[False, True], ids=["unsorted", "sorted"])
def is_sorted(request):
return request.param

Expand Down
Loading
Loading