-
Notifications
You must be signed in to change notification settings - Fork 915
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add handling for nested dicts in dask-cudf groupby #9054
Changes from 2 commits
1964f53
2805907
cc82256
05afea7
0954d23
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -198,34 +198,8 @@ def groupby_agg( | |
in `dask.dataframe`, because it allows the cudf backend to | ||
perform multiple aggregations at once. | ||
""" | ||
|
||
# Deal with default split_out and split_every params | ||
if split_every is False: | ||
split_every = ddf.npartitions | ||
split_every = split_every or 8 | ||
split_out = split_out or 1 | ||
|
||
# Standardize `gb_cols` and `columns` lists | ||
aggs = _redirect_aggs(aggs_in.copy()) | ||
if isinstance(gb_cols, str): | ||
gb_cols = [gb_cols] | ||
columns = [c for c in ddf.columns if c not in gb_cols] | ||
str_cols_out = False | ||
if isinstance(aggs, dict): | ||
# Use `str_cols_out` to specify if the output columns | ||
# will have str (rather than MultiIndex/tuple) names. | ||
# This happens when all values in the `aggs` dict are | ||
# strings (no lists) | ||
str_cols_out = True | ||
for col in aggs: | ||
if isinstance(aggs[col], str) or callable(aggs[col]): | ||
aggs[col] = [aggs[col]] | ||
else: | ||
str_cols_out = False | ||
if col in gb_cols: | ||
columns.append(col) | ||
|
||
# Assert that aggregations are supported | ||
aggs = _redirect_aggs(aggs_in) | ||
_supported = { | ||
"count", | ||
"mean", | ||
|
@@ -244,10 +218,39 @@ def groupby_agg( | |
f"Aggregations must be specified with dict or list syntax." | ||
) | ||
|
||
# Always convert aggs to dict for consistency | ||
# Deal with default split_out and split_every params | ||
if split_every is False: | ||
split_every = ddf.npartitions | ||
split_every = split_every or 8 | ||
split_out = split_out or 1 | ||
|
||
# Standardize `gb_cols`, `columns`, and `aggs` | ||
if isinstance(gb_cols, str): | ||
gb_cols = [gb_cols] | ||
columns = [c for c in ddf.columns if c not in gb_cols] | ||
if isinstance(aggs, list): | ||
aggs = {col: aggs for col in columns} | ||
|
||
# Assert if our output will have a MultiIndex; this will be the case if | ||
# any value in the `aggs` dict is not a string (i.e. multiple/named | ||
# aggregations per column) | ||
str_cols_out = True | ||
aggs_renames = {} | ||
for col in aggs: | ||
if isinstance(aggs[col], str) or callable(aggs[col]): | ||
aggs[col] = [aggs[col]] | ||
elif isinstance(aggs[col], dict): | ||
str_cols_out = False | ||
col_aggs = [] | ||
for k, v in aggs[col].items(): | ||
aggs_renames[_make_name(col, v, sep=sep)] = k | ||
col_aggs.append(v) | ||
aggs[col] = col_aggs | ||
else: | ||
str_cols_out = False | ||
if col in gb_cols: | ||
columns.append(col) | ||
|
||
# Begin graph construction | ||
dsk = {} | ||
token = tokenize(ddf, gb_cols, aggs) | ||
|
@@ -314,6 +317,15 @@ def groupby_agg( | |
for col in aggs: | ||
_aggs[col] = _aggs[col][0] | ||
_meta = ddf._meta.groupby(gb_cols, as_index=as_index).agg(_aggs) | ||
if aggs_renames: | ||
col_array = [] | ||
agg_array = [] | ||
for col, agg in _meta.columns: | ||
col_array.append(col) | ||
agg_array.append( | ||
aggs_renames.get(_make_name(col, agg, sep=sep), agg) | ||
) | ||
_meta.columns = pd.MultiIndex.from_arrays([col_array, agg_array]) | ||
for s in range(split_out): | ||
dsk[(gb_agg_name, s)] = ( | ||
_finalize_gb_agg, | ||
|
@@ -326,6 +338,7 @@ def groupby_agg( | |
sort, | ||
sep, | ||
str_cols_out, | ||
aggs_renames, | ||
) | ||
|
||
divisions = [None] * (split_out + 1) | ||
|
@@ -350,6 +363,9 @@ def _redirect_aggs(arg): | |
for col in arg: | ||
if isinstance(arg[col], list): | ||
new_arg[col] = [redirects.get(agg, agg) for agg in arg[col]] | ||
elif isinstance(arg[col], dict): | ||
for k, v in arg[col].items(): | ||
new_arg[col] = {k: redirects.get(v, v)} | ||
else: | ||
new_arg[col] = redirects.get(arg[col], arg[col]) | ||
return new_arg | ||
|
@@ -367,6 +383,8 @@ def _is_supported(arg, supported: set): | |
for col in arg: | ||
if isinstance(arg[col], list): | ||
_global_set = _global_set.union(set(arg[col])) | ||
elif isinstance(arg[col], dict): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Does order matter for the dict values here? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Shouldn't matter here, since we only need the set of values. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Great, sounds good! For the most part the code looks pretty good to me! |
||
_global_set = _global_set.union(set(arg[col].values())) | ||
else: | ||
_global_set.add(arg[col]) | ||
else: | ||
|
@@ -460,10 +478,8 @@ def _tree_node_agg(dfs, gb_cols, split_out, dropna, sort, sep): | |
agg = col.split(sep)[-1] | ||
if agg in ("count", "sum"): | ||
agg_dict[col] = ["sum"] | ||
elif agg in ("min", "max"): | ||
elif agg in ("min", "max", "collect"): | ||
agg_dict[col] = [agg] | ||
elif agg == "collect": | ||
agg_dict[col] = ["collect"] | ||
else: | ||
raise ValueError(f"Unexpected aggregation: {agg}") | ||
|
||
|
@@ -508,6 +524,7 @@ def _finalize_gb_agg( | |
sort, | ||
sep, | ||
str_cols_out, | ||
aggs_renames, | ||
): | ||
""" Final aggregation task. | ||
|
||
|
@@ -564,7 +581,7 @@ def _finalize_gb_agg( | |
else: | ||
name, agg = col.split(sep) | ||
col_array.append(name) | ||
agg_array.append(agg) | ||
agg_array.append(aggs_renames.get(col, agg)) | ||
if str_cols_out: | ||
gb.columns = col_array | ||
else: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't like that we have to do the aggregation renames for both
`_meta`
and the groupby result, but this is required so that we have the correct `final_columns`
for the last step of `_finalize_gb_agg()`.
It would be nice if we also supported nested dict aggregations in cuDF's groupby so that `_meta`
would have the correct index without any additional steps in dask-cuDF. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think @shwina said this could be done but would require some effort. Since pandas does not support nested dicts it seemed like cuDF did not have to go down this path. We could be wrong and if you feel strongly you should speak up
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If pandas doesn't support something ugly, I'd lean away from doing it in cudf for the sake of dask-cudf logic :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah I generally agree - there could be larger motivations to want nested renaming support for groupby in cuDF, but I don't think this case alone is a good enough reason to work on it