Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add collect list to dask-cudf groupby aggregations #8045

Merged
merged 10 commits into from
Jul 6, 2021
19 changes: 13 additions & 6 deletions python/dask_cudf/dask_cudf/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def aggregate(self, arg, split_every=None, split_out=1):
if arg == "size":
return self.size()

_supported = {"count", "mean", "std", "var", "sum", "min", "max"}
_supported = {"count", "mean", "std", "var", "sum", "min", "max", list}
if (
isinstance(self.obj, DaskDataFrame)
and isinstance(self.index, (str, list))
Expand Down Expand Up @@ -107,7 +107,7 @@ def aggregate(self, arg, split_every=None, split_out=1):
if arg == "size":
return self.size()

_supported = {"count", "mean", "std", "var", "sum", "min", "max"}
_supported = {"count", "mean", "std", "var", "sum", "min", "max", list}
if (
isinstance(self.obj, DaskDataFrame)
and isinstance(self.index, (str, list))
Expand Down Expand Up @@ -145,7 +145,7 @@ def groupby_agg(

This aggregation algorithm only supports the following options:

{"count", "mean", "std", "var", "sum", "min", "max"}
{"count", "mean", "std", "var", "sum", "min", "max", list}

This "optimized" approach is more performant than the algorithm
in `dask.dataframe`, because it allows the cudf backend to
Expand All @@ -171,15 +171,15 @@ def groupby_agg(
# strings (no lists)
str_cols_out = True
for col in aggs:
if isinstance(aggs[col], str):
if isinstance(aggs[col], str) or callable(aggs[col]):
aggs[col] = [aggs[col]]
else:
str_cols_out = False
if col in gb_cols:
columns.append(col)

# Assert that aggregations are supported
_supported = {"count", "mean", "std", "var", "sum", "min", "max"}
_supported = {"count", "mean", "std", "var", "sum", "min", "max", list}
if not _is_supported(aggs, _supported):
raise ValueError(
f"Supported aggs include {_supported} for groupby_agg API. "
Expand Down Expand Up @@ -255,7 +255,12 @@ def groupby_agg(
# be str, rather than tuples).
for col in aggs:
_aggs[col] = _aggs[col][0]
_meta = ddf._meta.groupby(gb_cols, as_index=as_index).agg(_aggs)
try:
_meta = ddf._meta.groupby(gb_cols, as_index=as_index).agg(_aggs)
except NotImplementedError:
_meta = ddf._meta_nonempty.groupby(gb_cols, as_index=as_index).agg(
_aggs
)
charlesbluca marked this conversation as resolved.
Show resolved Hide resolved
for s in range(split_out):
dsk[(gb_agg_name, s)] = (
_finalize_gb_agg,
Expand Down Expand Up @@ -381,6 +386,8 @@ def _tree_node_agg(dfs, gb_cols, split_out, dropna, sort, sep):
agg_dict[col] = ["sum"]
elif agg in ("min", "max"):
agg_dict[col] = [agg]
elif agg == "list":
agg_dict[col] = [list]
else:
raise ValueError(f"Unexpected aggregation: {agg}")

Expand Down