Skip to content

Commit

Permalink
SNOW-1476363: Add groupby().size() pandas API (#1763)
Browse files Browse the repository at this point in the history
Signed-off-by: Devin Petersohn <[email protected]>
  • Loading branch information
sfc-gh-dpetersohn authored Jun 14, 2024
1 parent 42ac032 commit 04370fe
Show file tree
Hide file tree
Showing 8 changed files with 286 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
- Added support for `Series.dt.dayofweek`, `Series.dt.day_of_week`, `Series.dt.dayofyear`, and `Series.dt.day_of_year`.
- Added support for `Series.str.__getitem__` (`Series.str[...]`).
- Added support for `Series.str.lstrip` and `Series.str.rstrip`.
- Added support for `DataFrameGroupBy.size` and `SeriesGroupBy.size`.
- Added support for `DataFrame.expanding` and `Series.expanding` for aggregations `count`, `sum`, `min`, `max`, `mean`, `std`, and `var` with `axis=0`.
- Added support for `DataFrame.rolling` and `Series.rolling` for aggregation `count` with `axis=0`.
- Added support for `Series.str.match`.
Expand Down
2 changes: 2 additions & 0 deletions docs/source/modin/groupby.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ GroupBy
DataFrameGroupBy.quantile
DataFrameGroupBy.rank
DataFrameGroupBy.shift
DataFrameGroupBy.size
DataFrameGroupBy.std
DataFrameGroupBy.sum
DataFrameGroupBy.tail
Expand All @@ -76,6 +77,7 @@ GroupBy
SeriesGroupBy.quantile
SeriesGroupBy.rank
SeriesGroupBy.shift
SeriesGroupBy.size
SeriesGroupBy.std
SeriesGroupBy.sum
SeriesGroupBy.tail
Expand Down
2 changes: 1 addition & 1 deletion docs/source/modin/supported/groupby_supported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ Computations/descriptive stats
| ``shift`` | P | ``Y`` if ``axis = 0``, ``freq`` is None, |
| | | ``level`` is None, and ``by`` is in the columns |
+-----------------------------+---------------------------------+----------------------------------------------------+
| ``size`` | N | |
| ``size`` | Y | |
+-----------------------------+---------------------------------+----------------------------------------------------+
| ``skew`` | N | |
+-----------------------------+---------------------------------+----------------------------------------------------+
Expand Down
20 changes: 19 additions & 1 deletion src/snowflake/snowpark/modin/pandas/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,21 @@ def all(self, skipna=True):

def size(self):
    """
    Compute group sizes.

    Returns
    -------
    Series or DataFrame
        Per-group row counts: a Series when ``as_index=True`` (the default),
        a DataFrame when ``as_index=False``.
    """
    # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions
    # NOTE: the stale ``ErrorMessage.method_not_implemented_error`` call that
    # preceded this implementation was removed — it made the method raise
    # unconditionally before reaching the aggregation below.
    result = self._wrap_aggregation(
        type(self._query_compiler).groupby_size,
        numeric_only=False,
    )
    # The query compiler returns a one-data-column frame; squeeze it into a
    # Series for the as_index=True path.
    if not isinstance(result, Series):
        result = result.squeeze(axis=1)
    if not self._kwargs.get("as_index") and not isinstance(result, Series):
        # NOTE(review): with as_index=False the unnamed size column is renamed
        # to "index"; pandas itself names this column "size" — confirm intended.
        result = (
            result.rename(columns={MODIN_UNNAMED_SERIES_LABEL: "index"})
            if MODIN_UNNAMED_SERIES_LABEL in result.columns
            else result
        )
    elif isinstance(self._df, Series):
        # SeriesGroupBy results carry the original series' name.
        result.name = self._df.name
    return result

def sum(
self,
Expand Down Expand Up @@ -1190,6 +1204,10 @@ def is_monotonic_increasing(self):
name="is_monotonic_increasing", class_="GroupBy"
)

def size(self):
    # TODO: Remove this once SNOW-1478924 is fixed
    # Delegate to the shared GroupBy implementation, then restore the name of
    # the selected column on the resulting series.
    counts = super().size()
    return counts.rename(self._df.columns[-1])

def aggregate(
self,
func: Optional[AggFuncType] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import logging
import re
import typing
import uuid
from collections.abc import Hashable, Iterable, Mapping, Sequence
from datetime import timedelta, tzinfo
from typing import Any, Callable, Literal, Optional, Union, get_args
Expand Down Expand Up @@ -3643,6 +3644,90 @@ def groupby_shift(
)
)

def groupby_size(
    self,
    by: Any,
    axis: int,
    groupby_kwargs: dict[str, Any],
    agg_args: tuple[Any, ...],
    agg_kwargs: dict[str, Any],
    drop: bool = False,
    **kwargs: dict[str, Any],
) -> "SnowflakeQueryCompiler":
    """
    Compute the number of rows in each group (``groupby(...).size()``).

    For example, with a dataframe created as follows::

        import pandas as pd

        data = [[1, 2, 3], [1, 5, 6], [2, 5, 8], [2, 6, 9]]
        df = pd.DataFrame(
            data,
            columns=["a", "b", "c"],
            index=["tuna", "salmon", "catfish", "goldfish"],
        )

    ``df.groupby("a").size()`` yields::

        a
        1    2
        2    2
        dtype: int64

    Args:
        by: mapping, series, callable, label, pd.Grouper, BaseQueryCompiler, list of such.
            Use this to determine the groups.
        axis: 0 (index) or 1 (columns).
        groupby_kwargs: dict
            Keyword arguments passed for the groupby (``level``, ``as_index``, ...).
        agg_args: tuple
            The aggregation args, unused in `groupby_size`.
        agg_kwargs: dict
            The aggregation keyword args, unused in `groupby_size`.
        drop: bool
            Drop the `by` column, unused in `groupby_size`.

    Returns:
        SnowflakeQueryCompiler: The result of groupby_size()

    Raises:
        NotImplementedError: if the groupby arguments are not supported by the
            Snowflake backend (see ``check_is_groupby_supported_by_snowflake``).
    """
    level = groupby_kwargs.get("level", None)
    is_supported = check_is_groupby_supported_by_snowflake(by, level, axis)

    if not is_supported:
        ErrorMessage.not_implemented(
            "Snowpark pandas GroupBy.size does not yet support pd.Grouper, axis == 1, by != None and level != None, by containing any non-pandas hashable labels, or unsupported aggregation parameters."
        )
    # Normalize a scalar `by` label into a list for take_2d_labels/groupby_agg.
    if not is_list_like(by):
        by = [by]
    # Random suffix so the temporary position column cannot collide with an
    # existing data column label.
    positions_col_name = f"__TEMP_POS_NAME_{uuid.uuid4().hex[-6:]}__"
    # We reset index twice to ensure we perform the count aggregation on the row
    # positions (which cannot be null). We name the column a unique new name to
    # avoid collisions. We rename them to their final names at the end.
    result = (
        self.reset_index(drop=True)
        .reset_index(drop=False, names=positions_col_name)
        .take_2d_labels(slice(None), [positions_col_name] + by)
        .groupby_agg(
            by,
            "count",
            axis,
            groupby_kwargs,
            (),
            {},
        )
    )
    if not groupby_kwargs.get("as_index", True):
        # as_index=False: expose the count column under the name "size".
        return result.rename(columns_renamer={positions_col_name: "size"})
    else:
        # as_index=True: mark the single data column as unnamed so the
        # frontend can squeeze it into a Series.
        return result.rename(
            columns_renamer={positions_col_name: MODIN_UNNAMED_SERIES_LABEL}
        )

def groupby_groups(
self,
by: Any,
Expand Down
87 changes: 87 additions & 0 deletions src/snowflake/snowpark/modin/plugin/docstrings/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1328,6 +1328,48 @@ def all():
pass

def size():
    """
    Compute group sizes.

    Returns
    -------
    DataFrame or Series
        Number of rows in each group as a Series if as_index is True
        or a DataFrame if as_index is False.

    Examples
    --------
    >>> data = [[1, 2, 3], [1, 5, 6], [7, 8, 9]]
    >>> df = pd.DataFrame(data, columns=["a", "b", "c"],
    ...                   index=["owl", "toucan", "eagle"])
    >>> df
            a  b  c
    owl     1  2  3
    toucan  1  5  6
    eagle   7  8  9
    >>> df.groupby("a").size()
    a
    1    2
    7    1
    dtype: int64

    For SeriesGroupBy:

    >>> data = [[1, 2, 3], [1, 5, 6], [7, 8, 9]]
    >>> df = pd.DataFrame(data, columns=["a", "b", "c"],
    ...                   index=["owl", "toucan", "eagle"])
    >>> df
            a  b  c
    owl     1  2  3
    toucan  1  5  6
    eagle   7  8  9
    >>> df.groupby("a")["b"].size()
    a
    1    2
    7    1
    Name: b, dtype: int64
    """
    pass

@doc(
Expand Down Expand Up @@ -1922,6 +1964,51 @@ def nunique():
Series
"""

def size():
    """
    Compute group sizes.

    Returns
    -------
    DataFrame or Series
        Number of rows in each group as a Series if as_index is True
        or a DataFrame if as_index is False.

    Examples
    --------
    >>> data = [[1, 2, 3], [1, 5, 6], [7, 8, 9]]
    >>> df = pd.DataFrame(data, columns=["a", "b", "c"],
    ...                   index=["owl", "toucan", "eagle"])
    >>> df
            a  b  c
    owl     1  2  3
    toucan  1  5  6
    eagle   7  8  9
    >>> df.groupby("a").size()
    a
    1    2
    7    1
    dtype: int64

    For SeriesGroupBy:

    >>> data = [[1, 2, 3], [1, 5, 6], [7, 8, 9]]
    >>> df = pd.DataFrame(data, columns=["a", "b", "c"],
    ...                   index=["owl", "toucan", "eagle"])
    >>> df
            a  b  c
    owl     1  2  3
    toucan  1  5  6
    eagle   7  8  9
    >>> df.groupby("a")["b"].size()
    a
    1    2
    7    1
    Name: b, dtype: int64
    """
    pass

def unique(self):
    # Docstring stub with no docstring yet — presumably left blank until the
    # API is documented; verify before adding one, since this module's
    # docstrings are attached to the real methods.
    pass

Expand Down
91 changes: 91 additions & 0 deletions tests/integ/modin/groupby/test_groupby_size.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#
import modin.pandas as pd
import numpy as np
import pytest

import snowflake.snowpark.modin.plugin # noqa: F401
from tests.integ.modin.sql_counter import SqlCounter, sql_count_checker
from tests.integ.modin.utils import eval_snowpark_pandas_result


@pytest.mark.parametrize(
    "by",
    [
        "col1_grp",
        "col2_int64",
        "col3_int_identical",
        "col4_int32",
        "col6_mixed",
        "col7_bool",
        "col8_bool_missing",
        "col9_int_missing",
        "col10_mixed_missing",
        ["col1_grp", "col2_int64"],
        ["col6_mixed", "col7_bool", "col3_int_identical"],
    ],
)
@pytest.mark.parametrize("as_index", [True, False])
def test_groupby_size(by, as_index):
    # Compare Snowpark pandas groupby(...).size() against native pandas for
    # single- and multi-column `by` keys across string, integer (several
    # widths), boolean, and mixed dtypes, including columns with missing
    # values, in both as_index modes.
    snowpark_pandas_df = pd.DataFrame(
        {
            "col1_grp": ["g1", "g2", "g0", "g0", "g2", "g3", "g0", "g2", "g3"],
            "col2_int64": np.arange(9, dtype="int64") // 3,
            "col3_int_identical": [2] * 9,
            "col4_int32": np.arange(9, dtype="int32") // 4,
            "col5_int16": np.arange(9, dtype="int16") // 3,
            "col6_mixed": np.concatenate(
                [
                    np.arange(3, dtype="int64") // 3,
                    np.arange(3, dtype="int32") // 3,
                    np.arange(3, dtype="int16") // 3,
                ]
            ),
            "col7_bool": [True] * 5 + [False] * 4,
            "col8_bool_missing": [
                True,
                None,
                False,
                False,
                None,
                None,
                True,
                False,
                None,
            ],
            "col9_int_missing": [5, 6, np.nan, 2, 1, np.nan, 5, np.nan, np.nan],
            "col10_mixed_missing": np.concatenate(
                [
                    np.arange(2, dtype="int64") // 3,
                    [np.nan],
                    np.arange(2, dtype="int32") // 3,
                    [np.nan],
                    np.arange(2, dtype="int16") // 3,
                ]
            ),
        }
    )
    pandas_df = snowpark_pandas_df.to_pandas()
    # NOTE(review): as_index=False is expected to cost one extra query —
    # presumably for the post-aggregation rename/reset step; confirm.
    with SqlCounter(query_count=1 if as_index else 2):
        eval_snowpark_pandas_result(
            snowpark_pandas_df,
            pandas_df,
            lambda df: df.groupby(by, as_index=as_index).size(),
        )

    # DataFrame with __getitem__
    with SqlCounter(query_count=1 if as_index else 2):
        eval_snowpark_pandas_result(
            snowpark_pandas_df,
            pandas_df,
            lambda df: df.groupby(by, as_index=as_index)["col5_int16"].size(),
        )


@sql_count_checker(query_count=0)
def test_error_checking():
    # Grouping a series by another Series is not supported by the Snowflake
    # backend, so size() must raise before issuing any SQL.
    keys = pd.Series(list("abc") * 4)
    with pytest.raises(NotImplementedError):
        keys.groupby(keys).size()
2 changes: 0 additions & 2 deletions tests/unit/modin/test_groupby_unsupported.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
(lambda se: se.groupby("A").rolling(2), "rolling"),
(lambda se: se.groupby("A").sample(n=1, random_state=1), "sample"),
(lambda se: se.groupby("A").sem(), "sem"),
(lambda se: se.groupby("A").size(), "size"),
(lambda se: se.groupby("A").skew(), "skew"),
(lambda se: se.groupby("A").take(2), "take"),
(lambda se: se.groupby("A").expanding(), "expanding"),
Expand Down Expand Up @@ -91,7 +90,6 @@ def test_series_groupby_unsupported_methods_raises(
(lambda df: df.groupby("A").rolling(2), "rolling"),
(lambda df: df.groupby("A").sample(n=1, random_state=1), "sample"),
(lambda df: df.groupby("A").sem(), "sem"),
(lambda df: df.groupby("A").size(), "size"),
(lambda df: df.groupby("A").skew(), "skew"),
(lambda df: df.groupby("A").take(2), "take"),
(lambda df: df.groupby("A").expanding(), "expanding"),
Expand Down

0 comments on commit 04370fe

Please sign in to comment.