Skip to content

Commit

Permalink
Lint
Browse files: browse the repository at this point in the history
Signed-off-by: Devin Petersohn <[email protected]>
  • Loading branch information
sfc-gh-dpetersohn committed Jun 12, 2024
1 parent 268fdc1 commit 438d5b4
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3495,14 +3495,14 @@ def groupby_shift(

def groupby_size(
self,
by,
axis,
groupby_kwargs,
agg_args,
agg_kwargs,
drop=False,
**kwargs
):
by: Any,
axis: int,
groupby_kwargs: dict[str, Any],
agg_args: tuple[Any],
agg_kwargs: dict[str, Any],
drop: bool = False,
**kwargs: dict[str, Any],
) -> "SnowflakeQueryCompiler":
"""
compute groupby with size.
With a dataframe created with following:
Expand Down Expand Up @@ -3548,17 +3548,18 @@ def groupby_size(
by = [by]
# We reset index twice to ensure we perform the count aggregation on the row
# positions (which cannot be null).
result = self.reset_index(drop=True).reset_index(
drop=False, names=MODIN_UNNAMED_SERIES_LABEL
).take_2d_labels(
slice(None), [MODIN_UNNAMED_SERIES_LABEL] + by
).groupby_agg(
by,
"count",
axis,
groupby_kwargs,
(),
{},
result = (
self.reset_index(drop=True)
.reset_index(drop=False, names=MODIN_UNNAMED_SERIES_LABEL)
.take_2d_labels(slice(None), [MODIN_UNNAMED_SERIES_LABEL] + by)
.groupby_agg(
by,
"count",
axis,
groupby_kwargs,
(),
{},
)
)
if not groupby_kwargs.get("as_index", True):
return result.rename(columns_renamer={MODIN_UNNAMED_SERIES_LABEL: "size"})
Expand Down
6 changes: 3 additions & 3 deletions tests/integ/modin/groupby/test_groupby_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest

import snowflake.snowpark.modin.plugin # noqa: F401
from tests.integ.modin.sql_counter import SqlCounter, sql_count_checker
from tests.integ.modin.sql_counter import SqlCounter
from tests.integ.modin.utils import eval_snowpark_pandas_result


Expand All @@ -24,8 +24,8 @@
"col9_int_missing",
"col10_mixed_missing",
["col1_grp", "col2_int64"],
["col6_mixed", "col7_bool", "col3_int_identical"]
]
["col6_mixed", "col7_bool", "col3_int_identical"],
],
)
@pytest.mark.parametrize("as_index", [True, False])
def test_groupby_size(by, as_index):
Expand Down

0 comments on commit 438d5b4

Please sign in to comment.