Skip to content

Commit

Permalink
Simplify old logic
Browse files Browse the repository at this point in the history
  • Loading branch information
vyasr committed Feb 2, 2024
1 parent 21176f2 commit 97cb7a5
Showing 1 changed file with 18 additions and 51 deletions.
69 changes: 18 additions & 51 deletions python/cudf/cudf/_lib/groupby.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -94,41 +94,17 @@ def _(dtype: DecimalDtype):
return _DECIMAL_AGGS


cdef _agg_result_from_pylibcudf_tables(
pylibcudf_results,
set column_included,
int n_input_columns
):
"""Construct the list of result columns from libcudf result. The result
contains the same number of lists as the number of input columns. Result
for an input column that has no applicable aggregations is an empty list.
"""
cdef:
int i
int result_index = 0
result_columns = []
for i in range(n_input_columns):
if i in column_included:
result_columns.append(
columns_from_pylibcudf_table(pylibcudf_results[result_index])
)
result_index += 1
else:
result_columns.append([])
return result_columns


cdef class GroupBy:
cdef dict __dict__

def __init__(self, keys, dropna=True):
self._groupby = pylibcudf.groupby.GroupBy(
pylibcudf.table.Table([c.to_pylibcudf(mode="read") for c in keys]),
pylibcudf.types.NullPolicy.EXCLUDE if dropna
else pylibcudf.types.NullPolicy.INCLUDE
)

with acquire_spill_lock() as spill_lock:
self._groupby = pylibcudf.groupby.GroupBy(
pylibcudf.table.Table([c.to_pylibcudf(mode="read") for c in keys]),
pylibcudf.types.NullPolicy.EXCLUDE if dropna
else pylibcudf.types.NullPolicy.INCLUDE
)

# We spill lock the columns while this GroupBy instance is alive.
self._spill_lock = spill_lock

Expand Down Expand Up @@ -180,45 +156,36 @@ cdef class GroupBy:
-------
Frame of aggregated values
"""
alg = "scan" if _is_all_scan_aggregate(aggregations) else "aggregate"

allow_empty = all(len(v) == 0 for v in aggregations)

included_aggregations = []
column_included = set()
column_included = []
requests = []
for i, (col, aggs) in enumerate(zip(values, aggregations)):
dtype = col.dtype

valid_aggregations = get_valid_aggregation(dtype)
valid_aggregations = get_valid_aggregation(col.dtype)
included_aggregations_i = []

col_aggregations = []
for agg in aggs:
agg_obj = make_groupby_aggregation(agg)
if (valid_aggregations == "ALL"
or agg_obj.kind in valid_aggregations):
if valid_aggregations == "ALL" or agg_obj.kind in valid_aggregations:
included_aggregations_i.append((agg, agg_obj.kind))
col_aggregations.append(agg_obj.c_obj)
included_aggregations.append(included_aggregations_i)
if col_aggregations:
requests.append(pylibcudf.groupby.GroupByRequest(
col.to_pylibcudf(mode="read"), col_aggregations
))
column_included.add(i)
if not requests and not allow_empty:
raise DataError("All requested aggregations are unsupported.")
column_included.append(i)

keys, results = self._groupby.aggregate(requests) if alg == "aggregate" \
else self._groupby.scan(requests)
if not requests and any(len(v) > 0 for v in aggregations):
raise DataError("All requested aggregations are unsupported.")

grouped_keys = columns_from_pylibcudf_table(keys)
keys, results = self._groupby.scan(requests) if \
_is_all_scan_aggregate(aggregations) else self._groupby.aggregate(requests)

result_columns = _agg_result_from_pylibcudf_tables(
results, column_included, len(values)
)
result_columns = [[] for _ in range(len(values))]
for i, result in zip(column_included, results):
result_columns[i] = columns_from_pylibcudf_table(result)

return result_columns, grouped_keys, included_aggregations
return result_columns, columns_from_pylibcudf_table(keys), included_aggregations

def shift(self, list values, int periods, list fill_values):
keys, shifts = self._groupby.shift(
Expand Down

0 comments on commit 97cb7a5

Please sign in to comment.