From 2f302122ad8897e2dfecabf82e79de5c28dc7ac3 Mon Sep 17 00:00:00 2001 From: Devin Petersohn Date: Wed, 12 Jun 2024 16:44:38 -0500 Subject: [PATCH 1/2] SNOW-1478406: Avoid unnecessary queries on squeeze Signed-off-by: Devin Petersohn --- .../snowpark/modin/pandas/dataframe.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/snowflake/snowpark/modin/pandas/dataframe.py b/src/snowflake/snowpark/modin/pandas/dataframe.py index 8896f698279..681d2cfb9a6 100644 --- a/src/snowflake/snowpark/modin/pandas/dataframe.py +++ b/src/snowflake/snowpark/modin/pandas/dataframe.py @@ -2347,15 +2347,15 @@ def squeeze(self, axis: Axis | None = None): len_columns = self._query_compiler.get_axis_len(1) if axis == 1 and len_columns == 1: return Series(query_compiler=self._query_compiler) - # get_axis_len(0) results in a sql query to count number of rows in current - # dataframe. We should only compute len_index if axis is 0 or None. - len_index = len(self) - if axis is None and (len_columns == 1 or len_index == 1): - return Series(query_compiler=self._query_compiler).squeeze() - if axis == 0 and len_index == 1: - return Series(query_compiler=self.T._query_compiler) - else: - return self.copy() + if axis in [0, None]: + # get_axis_len(0) results in a sql query to count number of rows in current + # dataframe. We should only compute len_index if axis is 0 or None. + len_index = len(self) + if axis is None and (len_columns == 1 or len_index == 1): + return Series(query_compiler=self._query_compiler).squeeze() + if axis == 0 and len_index == 1: + return Series(query_compiler=self.T._query_compiler) + return self.copy() @dataframe_not_implemented() def stack(self, level=-1, dropna=True): # noqa: PR01, RT01, D200 From 70d4a45fd5b1b44ac259623576d48e6492922b7b Mon Sep 17 00:00:00 2001 From: Devin Petersohn Date: Wed, 12 Jun 2024 19:21:48 -0500 Subject: [PATCH 2/2] Reduce sql count checks Signed-off-by: Devin Petersohn --- tests/integ/modin/frame/test_loc.py | 11 ++--------- tests/integ/modin/frame/test_squeeze.py | 18 ++++++++++-------- 2 files changed, 12 insertions(+), 17 deletions(-) diff --git a/tests/integ/modin/frame/test_loc.py b/tests/integ/modin/frame/test_loc.py index 17b73223645..e048426d883 100644 --- a/tests/integ/modin/frame/test_loc.py +++ b/tests/integ/modin/frame/test_loc.py @@ -140,14 +140,7 @@ def key_type(request): def test_df_loc_get_tuple_key( row, col, str_index_snowpark_pandas_df, str_index_native_df ): - with SqlCounter( - query_count=2 - if is_scalar(row) - or isinstance(row, tuple) - or is_scalar(col) - or isinstance(col, tuple) - else 1 - ): + with SqlCounter(query_count=2 if is_scalar(row) or isinstance(row, tuple) else 1): eval_snowpark_pandas_result( str_index_snowpark_pandas_df, str_index_native_df, @@ -185,7 +178,7 @@ def test_df_loc_get_callable_key( def test_df_loc_get_col_non_boolean_key( key, str_index_snowpark_pandas_df, str_index_native_df ): - with SqlCounter(query_count=2 if is_scalar(key) or isinstance(key, tuple) else 1): + with SqlCounter(query_count=1): eval_snowpark_pandas_result( str_index_snowpark_pandas_df, str_index_native_df, diff --git a/tests/integ/modin/frame/test_squeeze.py b/tests/integ/modin/frame/test_squeeze.py index cc9091f10f4..ea1cb72ebd2 100644 --- a/tests/integ/modin/frame/test_squeeze.py +++ b/tests/integ/modin/frame/test_squeeze.py @@ -7,7 +7,7 @@ import pytest import snowflake.snowpark.modin.plugin # noqa: F401 -from tests.integ.modin.sql_counter import SqlCounter, sql_count_checker +from tests.integ.modin.sql_counter import SqlCounter from tests.integ.modin.utils import eval_snowpark_pandas_result @@ -33,8 +33,10 @@ def test_1d(axis): ) if axis is None: expected_query_count = 3 - else: + elif axis in [0, "index"]: expected_query_count = 2 + else: + expected_query_count = 1 with SqlCounter(query_count=expected_query_count): eval_snowpark_pandas_result( pd.DataFrame({"a": [1], "b": [2], "c": [3]}), @@ -43,13 +45,13 @@ def test_1d(axis): ) -@sql_count_checker(query_count=2) def test_2d(axis): - eval_snowpark_pandas_result( - pd.DataFrame({"A": [1, 2, 3], "B": [2, 3, 4]}), - native_pd.DataFrame({"A": [1, 2, 3], "B": [2, 3, 4]}), - lambda df: df.squeeze(axis=axis), - ) + with SqlCounter(query_count=1 if axis in [1, "columns"] else 2): + eval_snowpark_pandas_result( + pd.DataFrame({"A": [1, 2, 3], "B": [2, 3, 4]}), + native_pd.DataFrame({"A": [1, 2, 3], "B": [2, 3, 4]}), + lambda df: df.squeeze(axis=axis), + ) def test_scalar(axis):