diff --git a/CHANGELOG.md b/CHANGELOG.md index 1fc5b7d3776..970224ab235 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -45,6 +45,7 @@ - Added support for `Series.dt.dayofweek`, `Series.dt.day_of_week`, `Series.dt.dayofyear`, and `Series.dt.day_of_year`. - Added support for `Series.str.__getitem__` (`Series.str[...]`). - Added support for `Series.str.lstrip` and `Series.str.rstrip`. +- Added support for `DataFrameGroupBy.size` and `SeriesGroupBy.size`. - Added support for `DataFrame.expanding` and `Series.expanding` for aggregations `count`, `sum`, `min`, `max`, `mean`, `std`, and `var` with `axis=0`. - Added support for `DataFrame.rolling` and `Series.rolling` for aggregation `count` with `axis=0`. - Added support for `Series.str.match`. diff --git a/docs/source/modin/groupby.rst b/docs/source/modin/groupby.rst index cefcba4e3e2..7c951ae514f 100644 --- a/docs/source/modin/groupby.rst +++ b/docs/source/modin/groupby.rst @@ -50,6 +50,7 @@ GroupBy DataFrameGroupBy.quantile DataFrameGroupBy.rank DataFrameGroupBy.shift + DataFrameGroupBy.size DataFrameGroupBy.std DataFrameGroupBy.sum DataFrameGroupBy.tail @@ -76,6 +77,7 @@ GroupBy SeriesGroupBy.quantile SeriesGroupBy.rank SeriesGroupBy.shift + SeriesGroupBy.size SeriesGroupBy.std SeriesGroupBy.sum SeriesGroupBy.tail diff --git a/docs/source/modin/supported/groupby_supported.rst b/docs/source/modin/supported/groupby_supported.rst index 91580fb2064..7e42000fd64 100644 --- a/docs/source/modin/supported/groupby_supported.rst +++ b/docs/source/modin/supported/groupby_supported.rst @@ -154,7 +154,7 @@ Computations/descriptive stats | ``shift`` | P | ``Y`` if ``axis = 0``, ``freq`` is None, | | | | ``level`` is None, and ``by`` is in the columns | +-----------------------------+---------------------------------+----------------------------------------------------+ -| ``size`` | N | | +| ``size`` | Y | | +-----------------------------+---------------------------------+----------------------------------------------------+ | 
``skew`` | N | | +-----------------------------+---------------------------------+----------------------------------------------------+ diff --git a/src/snowflake/snowpark/modin/pandas/groupby.py b/src/snowflake/snowpark/modin/pandas/groupby.py index 2ceee75e73b..2ac8c8fe85e 100644 --- a/src/snowflake/snowpark/modin/pandas/groupby.py +++ b/src/snowflake/snowpark/modin/pandas/groupby.py @@ -737,7 +737,21 @@ def all(self, skipna=True): def size(self): # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions - ErrorMessage.method_not_implemented_error(name="size", class_="GroupBy") + result = self._wrap_aggregation( + type(self._query_compiler).groupby_size, + numeric_only=False, + ) + if not isinstance(result, Series): + result = result.squeeze(axis=1) + if not self._kwargs.get("as_index") and not isinstance(result, Series): + result = ( + result.rename(columns={MODIN_UNNAMED_SERIES_LABEL: "index"}) + if MODIN_UNNAMED_SERIES_LABEL in result.columns + else result + ) + elif isinstance(self._df, Series): + result.name = self._df.name + return result def sum( self, @@ -1190,6 +1204,10 @@ def is_monotonic_increasing(self): name="is_monotonic_increasing", class_="GroupBy" ) + def size(self): + # TODO: Remove this once SNOW-1478924 is fixed + return super().size().rename(self._df.columns[-1]) + def aggregate( self, func: Optional[AggFuncType] = None, diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index 9bcafd5246d..070f0466e88 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -7,6 +7,7 @@ import logging import re import typing +import uuid from collections.abc import Hashable, Iterable, Mapping, Sequence from datetime import timedelta, tzinfo from typing import Any, Callable, Literal, Optional, Union, get_args @@ 
-3643,6 +3644,90 @@ def groupby_shift( ) ) + def groupby_size( + self, + by: Any, + axis: int, + groupby_kwargs: dict[str, Any], + agg_args: tuple[Any], + agg_kwargs: dict[str, Any], + drop: bool = False, + **kwargs: dict[str, Any], + ) -> "SnowflakeQueryCompiler": + """ + Compute group sizes. + With a dataframe created with the following: + import pandas as pd + + data = [[1,2,3], [1, 5, 6], [2, 5, 8], [2, 6, 9]] + + df = pd.DataFrame(data, columns=["a", "b", "c"], index = ["tuna", "salmon", "catfish", "goldfish"]) + + df + + a b c + + tuna 1 2 3 + salmon 1 5 6 + catfish 2 5 8 + goldfish 2 6 9 + + df.groupby("a").size() + + a + 1 2 + 2 2 + dtype: int64 + + + Args: + by: mapping, series, callable, label, pd.Grouper, BaseQueryCompiler, list of such. + Use this to determine the groups. + axis: 0 (index) or 1 (columns). + groupby_kwargs: dict + keyword arguments passed for the groupby. + agg_args: tuple + The aggregation args, unused in `groupby_size`. + agg_kwargs: dict + The aggregation keyword args, unused in `groupby_size`. + drop: bool + Drop the `by` column, unused in `groupby_size`. + Returns: + SnowflakeQueryCompiler: The result of groupby_size() + """ + level = groupby_kwargs.get("level", None) + is_supported = check_is_groupby_supported_by_snowflake(by, level, axis) + + if not is_supported: + ErrorMessage.not_implemented( + "Snowpark pandas GroupBy.size does not yet support pd.Grouper, axis == 1, by != None and level != None, by containing any non-pandas hashable labels, or unsupported aggregation parameters." + ) + if not is_list_like(by): + by = [by] + positions_col_name = f"__TEMP_POS_NAME_{uuid.uuid4().hex[-6:]}__" + # We reset index twice to ensure we perform the count aggregation on the row + # positions (which cannot be null). We give the column a unique new name to + # avoid collisions. We rename it to its final name at the end. 
+ result = ( + self.reset_index(drop=True) + .reset_index(drop=False, names=positions_col_name) + .take_2d_labels(slice(None), [positions_col_name] + by) + .groupby_agg( + by, + "count", + axis, + groupby_kwargs, + (), + {}, + ) + ) + if not groupby_kwargs.get("as_index", True): + return result.rename(columns_renamer={positions_col_name: "size"}) + else: + return result.rename( + columns_renamer={positions_col_name: MODIN_UNNAMED_SERIES_LABEL} + ) + def groupby_groups( self, by: Any, diff --git a/src/snowflake/snowpark/modin/plugin/docstrings/groupby.py b/src/snowflake/snowpark/modin/plugin/docstrings/groupby.py index 8c899615d82..90424e8bd43 100644 --- a/src/snowflake/snowpark/modin/plugin/docstrings/groupby.py +++ b/src/snowflake/snowpark/modin/plugin/docstrings/groupby.py @@ -1328,6 +1328,48 @@ def all(): pass def size(): + """ + Compute group sizes. + + Returns + ------- + DataFrame or Series + Number of rows in each group as a Series if as_index is True + or a DataFrame if as_index is False. + + Examples + -------- + + >>> data = [[1, 2, 3], [1, 5, 6], [7, 8, 9]] + >>> df = pd.DataFrame(data, columns=["a", "b", "c"], + ... index=["owl", "toucan", "eagle"]) + >>> df + a b c + owl 1 2 3 + toucan 1 5 6 + eagle 7 8 9 + >>> df.groupby("a").size() + a + 1 2 + 7 1 + dtype: int64 + + For SeriesGroupBy: + + >>> data = [[1, 2, 3], [1, 5, 6], [7, 8, 9]] + >>> df = pd.DataFrame(data, columns=["a", "b", "c"], + ... index=["owl", "toucan", "eagle"]) + >>> df + a b c + owl 1 2 3 + toucan 1 5 6 + eagle 7 8 9 + >>> df.groupby("a")["b"].size() + a + 1 2 + 7 1 + Name: b, dtype: int64 + """ pass @doc( @@ -1922,6 +1964,51 @@ def nunique(): Series """ + def size(): + """ + Compute group sizes. + + Returns + ------- + DataFrame or Series + Number of rows in each group as a Series if as_index is True + or a DataFrame if as_index is False. + + Examples + -------- + + >>> data = [[1, 2, 3], [1, 5, 6], [7, 8, 9]] + >>> df = pd.DataFrame(data, columns=["a", "b", "c"], + ... 
index=["owl", "toucan", "eagle"]) + >>> df + a b c + owl 1 2 3 + toucan 1 5 6 + eagle 7 8 9 + >>> df.groupby("a").size() + a + 1 2 + 7 1 + dtype: int64 + + For SeriesGroupBy: + + >>> data = [[1, 2, 3], [1, 5, 6], [7, 8, 9]] + >>> df = pd.DataFrame(data, columns=["a", "b", "c"], + ... index=["owl", "toucan", "eagle"]) + >>> df + a b c + owl 1 2 3 + toucan 1 5 6 + eagle 7 8 9 + >>> df.groupby("a")["b"].size() + a + 1 2 + 7 1 + Name: b, dtype: int64 + """ + pass + def unique(self): pass diff --git a/tests/integ/modin/groupby/test_groupby_size.py b/tests/integ/modin/groupby/test_groupby_size.py new file mode 100644 index 00000000000..9f336f6619f --- /dev/null +++ b/tests/integ/modin/groupby/test_groupby_size.py @@ -0,0 +1,91 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# +import modin.pandas as pd +import numpy as np +import pytest + +import snowflake.snowpark.modin.plugin # noqa: F401 +from tests.integ.modin.sql_counter import SqlCounter, sql_count_checker +from tests.integ.modin.utils import eval_snowpark_pandas_result + + +@pytest.mark.parametrize( + "by", + [ + "col1_grp", + "col2_int64", + "col3_int_identical", + "col4_int32", + "col6_mixed", + "col7_bool", + "col8_bool_missing", + "col9_int_missing", + "col10_mixed_missing", + ["col1_grp", "col2_int64"], + ["col6_mixed", "col7_bool", "col3_int_identical"], + ], +) +@pytest.mark.parametrize("as_index", [True, False]) +def test_groupby_size(by, as_index): + snowpark_pandas_df = pd.DataFrame( + { + "col1_grp": ["g1", "g2", "g0", "g0", "g2", "g3", "g0", "g2", "g3"], + "col2_int64": np.arange(9, dtype="int64") // 3, + "col3_int_identical": [2] * 9, + "col4_int32": np.arange(9, dtype="int32") // 4, + "col5_int16": np.arange(9, dtype="int16") // 3, + "col6_mixed": np.concatenate( + [ + np.arange(3, dtype="int64") // 3, + np.arange(3, dtype="int32") // 3, + np.arange(3, dtype="int16") // 3, + ] + ), + "col7_bool": [True] * 5 + [False] * 4, + "col8_bool_missing": [ + True, + None, + 
False, + False, + None, + None, + True, + False, + None, + ], + "col9_int_missing": [5, 6, np.nan, 2, 1, np.nan, 5, np.nan, np.nan], + "col10_mixed_missing": np.concatenate( + [ + np.arange(2, dtype="int64") // 3, + [np.nan], + np.arange(2, dtype="int32") // 3, + [np.nan], + np.arange(2, dtype="int16") // 3, + [np.nan], + ] + ), + } + ) + pandas_df = snowpark_pandas_df.to_pandas() + with SqlCounter(query_count=1 if as_index else 2): + eval_snowpark_pandas_result( + snowpark_pandas_df, + pandas_df, + lambda df: df.groupby(by, as_index=as_index).size(), + ) + + # DataFrame with __getitem__ + with SqlCounter(query_count=1 if as_index else 2): + eval_snowpark_pandas_result( + snowpark_pandas_df, + pandas_df, + lambda df: df.groupby(by, as_index=as_index)["col5_int16"].size(), + ) + + +@sql_count_checker(query_count=0) +def test_error_checking(): + s = pd.Series(list("abc") * 4) + with pytest.raises(NotImplementedError): + s.groupby(s).size() diff --git a/tests/unit/modin/test_groupby_unsupported.py b/tests/unit/modin/test_groupby_unsupported.py index 8168c0045cc..881bb6d016b 100644 --- a/tests/unit/modin/test_groupby_unsupported.py +++ b/tests/unit/modin/test_groupby_unsupported.py @@ -41,7 +41,6 @@ (lambda se: se.groupby("A").rolling(2), "rolling"), (lambda se: se.groupby("A").sample(n=1, random_state=1), "sample"), (lambda se: se.groupby("A").sem(), "sem"), - (lambda se: se.groupby("A").size(), "size"), (lambda se: se.groupby("A").skew(), "skew"), (lambda se: se.groupby("A").take(2), "take"), (lambda se: se.groupby("A").expanding(), "expanding"), @@ -91,7 +90,6 @@ def test_series_groupby_unsupported_methods_raises( (lambda df: df.groupby("A").rolling(2), "rolling"), (lambda df: df.groupby("A").sample(n=1, random_state=1), "sample"), (lambda df: df.groupby("A").sem(), "sem"), - (lambda df: df.groupby("A").size(), "size"), (lambda df: df.groupby("A").skew(), "skew"), (lambda df: df.groupby("A").take(2), "take"), (lambda df: df.groupby("A").expanding(), "expanding"),