-
Notifications
You must be signed in to change notification settings - Fork 119
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Devin Petersohn <[email protected]>
- Loading branch information
1 parent
42ac032
commit 04370fe
Showing
8 changed files
with
286 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
# | ||
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. | ||
# | ||
import modin.pandas as pd | ||
import numpy as np | ||
import pytest | ||
|
||
import snowflake.snowpark.modin.plugin # noqa: F401 | ||
from tests.integ.modin.sql_counter import SqlCounter, sql_count_checker | ||
from tests.integ.modin.utils import eval_snowpark_pandas_result | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"by", | ||
[ | ||
"col1_grp", | ||
"col2_int64", | ||
"col3_int_identical", | ||
"col4_int32", | ||
"col6_mixed", | ||
"col7_bool", | ||
"col8_bool_missing", | ||
"col9_int_missing", | ||
"col10_mixed_missing", | ||
["col1_grp", "col2_int64"], | ||
["col6_mixed", "col7_bool", "col3_int_identical"], | ||
], | ||
) | ||
@pytest.mark.parametrize("as_index", [True, False]) | ||
def test_groupby_size(by, as_index): | ||
snowpark_pandas_df = pd.DataFrame( | ||
{ | ||
"col1_grp": ["g1", "g2", "g0", "g0", "g2", "g3", "g0", "g2", "g3"], | ||
"col2_int64": np.arange(9, dtype="int64") // 3, | ||
"col3_int_identical": [2] * 9, | ||
"col4_int32": np.arange(9, dtype="int32") // 4, | ||
"col5_int16": np.arange(9, dtype="int16") // 3, | ||
"col6_mixed": np.concatenate( | ||
[ | ||
np.arange(3, dtype="int64") // 3, | ||
np.arange(3, dtype="int32") // 3, | ||
np.arange(3, dtype="int16") // 3, | ||
] | ||
), | ||
"col7_bool": [True] * 5 + [False] * 4, | ||
"col8_bool_missing": [ | ||
True, | ||
None, | ||
False, | ||
False, | ||
None, | ||
None, | ||
True, | ||
False, | ||
None, | ||
], | ||
"col9_int_missing": [5, 6, np.nan, 2, 1, np.nan, 5, np.nan, np.nan], | ||
"col10_mixed_missing": np.concatenate( | ||
[ | ||
np.arange(2, dtype="int64") // 3, | ||
[np.nan], | ||
np.arange(2, dtype="int32") // 3, | ||
[np.nan], | ||
np.arange(2, dtype="int16") // 3, | ||
[np.nan], | ||
] | ||
), | ||
} | ||
) | ||
pandas_df = snowpark_pandas_df.to_pandas() | ||
with SqlCounter(query_count=1 if as_index else 2): | ||
eval_snowpark_pandas_result( | ||
snowpark_pandas_df, | ||
pandas_df, | ||
lambda df: df.groupby(by, as_index=as_index).size(), | ||
) | ||
|
||
# DataFrame with __getitem__ | ||
with SqlCounter(query_count=1 if as_index else 2): | ||
eval_snowpark_pandas_result( | ||
snowpark_pandas_df, | ||
pandas_df, | ||
lambda df: df.groupby(by, as_index=as_index)["col5_int16"].size(), | ||
) | ||
|
||
|
||
@sql_count_checker(query_count=0) | ||
def test_error_checking(): | ||
s = pd.Series(list("abc") * 4) | ||
with pytest.raises(NotImplementedError): | ||
s.groupby(s).size() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters