From 8862b80fdfcd8a0bcee1f52a7ce6811dd1d75671 Mon Sep 17 00:00:00 2001 From: Andong Zhan Date: Fri, 16 Aug 2024 17:19:33 -0700 Subject: [PATCH] SNOW-1625377 Raise NotImplementedError for timedelta (#2102) --- CHANGELOG.md | 5 +- .../plugin/_internal/aggregation_utils.py | 8 + .../modin/plugin/_internal/apply_utils.py | 2 + .../modin/plugin/_internal/concat_utils.py | 12 + .../plugin/_internal/cumulative_utils.py | 2 + .../modin/plugin/_internal/cut_utils.py | 2 + .../snowpark/modin/plugin/_internal/frame.py | 66 +++- .../modin/plugin/_internal/generator_utils.py | 2 + .../modin/plugin/_internal/groupby_utils.py | 2 + .../modin/plugin/_internal/indexing_utils.py | 59 +++ .../modin/plugin/_internal/isin_utils.py | 2 + .../modin/plugin/_internal/join_utils.py | 4 + .../modin/plugin/_internal/pivot_utils.py | 10 + .../modin/plugin/_internal/resample_utils.py | 6 + .../modin/plugin/_internal/transpose_utils.py | 2 + .../modin/plugin/_internal/unpivot_utils.py | 4 + .../snowpark/modin/plugin/_internal/utils.py | 8 + .../compiler/snowflake_query_compiler.py | 347 +++++++++++++++++- .../modin/plugin/utils/error_message.py | 6 + tests/integ/modin/frame/test_reset_index.py | 13 +- tests/integ/modin/series/test_cache_result.py | 14 +- tests/integ/modin/series/test_copy.py | 18 +- tests/integ/modin/series/test_shift.py | 18 +- tests/integ/modin/series/test_sort_index.py | 16 +- tests/integ/modin/types/test_timedelta.py | 20 + tests/unit/modin/conftest.py | 3 + tests/unit/modin/test_internal_frame.py | 32 ++ .../modin/test_snowflake_query_compiler.py | 2 + 28 files changed, 663 insertions(+), 22 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 05b56da1144..077aa2b7844 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -77,7 +77,10 @@ `is_month_start`, `is_month_end`, `is_quarter_start`, `is_quarter_end`, `is_year_start`, `is_year_end` and `is_leap_year`. - Added support for `Resampler.fillna` and `Resampler.bfill`. -- Added limited support for the `Timedelta` type, including creating `Timedelta` columns and `to_pandas`. +- Added limited support for the `Timedelta` type, including + - support for creating `Timedelta` columns and `to_pandas`. + - support for `copy`, `cache_result`, `shift`, and `sort_index`. + - `NotImplementedError` will be raised for the remaining methods that do not yet support `Timedelta`. - Added support for `Index.argmax` and `Index.argmin`. - Added support for index's arithmetic and comparison operators. - Added support for `Series.dt.round`. 
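The usage sketch below illustrates the behavior described in the CHANGELOG entry above. It is not part of the patch: the column name and data are made up, and it assumes an active Snowpark session plus the standard Snowpark pandas import pattern.

# Illustrative only; assumes an active Snowpark session. Column names and data are made up.
import modin.pandas as pd
import pandas as native_pd
import snowflake.snowpark.modin.plugin  # noqa: F401  # attaches the Snowflake backend

df = pd.DataFrame({"td": native_pd.timedelta_range("1 day", periods=3)})

df.copy()         # supported: the cached Timedelta type is preserved
df.shift(1)       # supported
df.sort_index()   # supported
df.to_pandas()    # supported: round-trips back to a native pandas Timedelta column

try:
    df.cumsum()   # one of the methods that now raises for Timedelta columns
except NotImplementedError as exc:
    print(exc)
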
diff --git a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py index 9b88b286d40..16a5c157904 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py @@ -652,6 +652,12 @@ def drop_non_numeric_data_columns( col.snowflake_quoted_identifier for col in data_column_to_retain ] + new_data_column_types = [ + type + for id, type in original_frame.snowflake_quoted_identifier_to_snowpark_pandas_type.items() + if id in new_data_column_snowflake_quoted_identifiers + ] + return SnowflakeQueryCompiler( InternalFrame.create( ordered_dataframe=original_frame.ordered_dataframe, @@ -660,6 +666,8 @@ def drop_non_numeric_data_columns( data_column_pandas_index_names=original_frame.data_column_pandas_index_names, index_column_pandas_labels=original_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=original_frame.index_column_snowflake_quoted_identifiers, + data_column_types=new_data_column_types, + index_column_types=original_frame.cached_index_column_snowpark_pandas_types, ) ) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py index 5127cafbc25..ecc9e8a041c 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py @@ -1257,6 +1257,8 @@ def groupby_apply_create_internal_frame_from_final_ordered_dataframe( + func_result_index_column_pandas_labels, index_column_snowflake_quoted_identifiers=group_quoted_identifiers + func_result_index_column_snowflake_quoted_identifiers, + data_column_types=None, + index_column_types=None, ) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/concat_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/concat_utils.py index 95c35343432..1bd123b5acc 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/concat_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/concat_utils.py @@ -98,6 +98,8 @@ def convert_to_single_level_index(frame: InternalFrame, axis: int) -> InternalFr data_column_pandas_index_names=[None], index_column_pandas_labels=frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=frame.index_column_snowflake_quoted_identifiers, + data_column_types=frame.cached_data_column_snowpark_pandas_types, + index_column_types=frame.cached_index_column_snowpark_pandas_types, ) else: WarningMessage.tuples_stored_as_array( @@ -122,6 +124,8 @@ def convert_to_single_level_index(frame: InternalFrame, axis: int) -> InternalFr data_column_pandas_labels=frame.data_column_pandas_labels, data_column_snowflake_quoted_identifiers=frame.data_column_snowflake_quoted_identifiers, data_column_pandas_index_names=frame.data_column_pandas_index_names, + data_column_types=None, + index_column_types=None, ) @@ -224,6 +228,8 @@ def union_all( data_column_pandas_index_names=frame1.data_column_pandas_index_names, index_column_pandas_labels=frame1.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=frame1.index_column_snowflake_quoted_identifiers, + data_column_types=None, + index_column_types=None, ) @@ -262,6 +268,8 @@ def add_key_as_index_columns(frame: InternalFrame, key: Hashable) -> InternalFra data_column_pandas_index_names=frame.data_column_pandas_index_names, index_column_pandas_labels=index_column_pandas_labels, 
index_column_snowflake_quoted_identifiers=index_column_snowflake_quoted_identifiers, + data_column_types=None, + index_column_types=None, ) @@ -322,6 +330,8 @@ def _select_columns( data_column_pandas_index_names=frame.data_column_pandas_index_names, index_column_pandas_labels=frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=frame.index_column_snowflake_quoted_identifiers, + data_column_types=None, + index_column_types=None, ) @@ -360,4 +370,6 @@ def add_global_ordering_columns(frame: InternalFrame, position: int) -> Internal data_column_pandas_index_names=frame.data_column_pandas_index_names, index_column_pandas_labels=frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=frame.index_column_snowflake_quoted_identifiers, + data_column_types=frame.cached_data_column_snowpark_pandas_types, + index_column_types=frame.cached_index_column_snowpark_pandas_types, ) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/cumulative_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/cumulative_utils.py index ce13b0082f7..8d3935c0199 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/cumulative_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/cumulative_utils.py @@ -196,6 +196,8 @@ def get_groupby_cumagg_frame_axis0( index_column_pandas_labels=result_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=result_frame.index_column_snowflake_quoted_identifiers, data_column_pandas_index_names=[None], + data_column_types=result_frame.cached_data_column_snowpark_pandas_types, + index_column_types=result_frame.cached_index_column_snowpark_pandas_types, ) else: return result_frame diff --git a/src/snowflake/snowpark/modin/plugin/_internal/cut_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/cut_utils.py index 01001aca696..3a0cf769169 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/cut_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/cut_utils.py @@ -310,6 +310,8 @@ def compute_bin_indices( data_column_snowflake_quoted_identifiers=[new_data_identifier], index_column_pandas_labels=value_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=value_index_identifiers, + data_column_types=None, + index_column_types=None, ) return new_frame diff --git a/src/snowflake/snowpark/modin/plugin/_internal/frame.py b/src/snowflake/snowpark/modin/plugin/_internal/frame.py index 93423e4027b..36ab44097e9 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/frame.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/frame.py @@ -83,9 +83,19 @@ def _create_snowflake_quoted_identifier_to_snowpark_pandas_type( dict mapping each column's Snowflake quoted identifier to the column's Snowpark pandas type. 
""" if data_column_types is not None: - assert len(data_column_types) == len(data_column_snowflake_quoted_identifiers) + assert len(data_column_types) == len( + data_column_snowflake_quoted_identifiers + ), ( + f"The length of data_column_types {data_column_types} is different from the length of " + f"data_column_snowflake_quoted_identifiers {data_column_snowflake_quoted_identifiers}" + ) if index_column_types is not None: - assert len(index_column_types) == len(index_column_snowflake_quoted_identifiers) + assert len(index_column_types) == len( + index_column_snowflake_quoted_identifiers + ), ( + f"The length of index_column_types {index_column_types} is different from the length of " + f"index_column_snowflake_quoted_identifiers {index_column_snowflake_quoted_identifiers}" + ) return MappingProxyType( { @@ -164,8 +174,8 @@ def create( data_column_snowflake_quoted_identifiers: list[str], index_column_pandas_labels: list[Hashable], index_column_snowflake_quoted_identifiers: list[str], - data_column_types: Optional[list[Optional[SnowparkPandasType]]] = None, - index_column_types: Optional[list[Optional[SnowparkPandasType]]] = None, + data_column_types: Optional[list[Optional[SnowparkPandasType]]], + index_column_types: Optional[list[Optional[SnowparkPandasType]]], ) -> "InternalFrame": """ Args: @@ -630,7 +640,14 @@ def get_snowflake_identifiers_for_levels(self, levels: list[int]) -> list[str]: def get_snowflake_identifiers_and_pandas_labels_from_levels( self, levels: list[int] - ) -> tuple[list[Hashable], list[str], list[Hashable], list[str]]: + ) -> tuple[ + list[Hashable], + list[str], + list[Optional[SnowparkPandasType]], + list[Hashable], + list[str], + list[Optional[SnowparkPandasType]], + ]: """ Selects snowflake identifiers and pandas labels from index columns in `levels`. Also returns snowflake identifiers and pandas labels not in `levels`. @@ -639,36 +656,45 @@ def get_snowflake_identifiers_and_pandas_labels_from_levels( levels: A list of integers represents levels in pandas Index. Returns: - A tuple contains 4 lists: + A tuple contains 6 lists: 1. The first list contains snowflake identifiers of index columns in `levels`. 2. The second list contains pandas labels of index columns in `levels`. - 3. The third list contains snowflake identifiers of index columns not in `levels`. - 4. The fourth list contains pandas labels of index columns not in `levels`. + 3. The third list contains Snowpark pandas types of index columns in `levels`. + 4. The fourth list contains snowflake identifiers of index columns not in `levels`. + 5. The fifth list contains pandas labels of index columns not in `levels`. + 6. The sixth list contains Snowpark pandas types of index columns not in `levels`. 
""" index_column_pandas_labels_in_levels = [] index_column_snowflake_quoted_identifiers_in_levels = [] + index_column_types_in_levels = [] index_column_pandas_labels_not_in_levels = [] index_column_snowflake_quoted_identifiers_not_in_levels = [] - for idx, (identifier, label) in enumerate( + index_column_types_not_in_levels = [] + for idx, (identifier, label, type) in enumerate( zip( self.index_column_snowflake_quoted_identifiers, self.index_column_pandas_labels, + self.cached_index_column_snowpark_pandas_types, ) ): if idx in levels: index_column_pandas_labels_in_levels.append(label) index_column_snowflake_quoted_identifiers_in_levels.append(identifier) + index_column_types_in_levels.append(type) else: index_column_pandas_labels_not_in_levels.append(label) index_column_snowflake_quoted_identifiers_not_in_levels.append( identifier ) + index_column_types_not_in_levels.append(type) return ( index_column_pandas_labels_in_levels, index_column_snowflake_quoted_identifiers_in_levels, + index_column_types_in_levels, index_column_pandas_labels_not_in_levels, index_column_snowflake_quoted_identifiers_not_in_levels, + index_column_types_not_in_levels, ) @functools.cached_property @@ -855,8 +881,10 @@ def ensure_row_position_column(self) -> "InternalFrame": data_column_pandas_labels=self.data_column_pandas_labels, data_column_snowflake_quoted_identifiers=self.data_column_snowflake_quoted_identifiers, data_column_pandas_index_names=self.data_column_pandas_index_names, + data_column_types=self.cached_data_column_snowpark_pandas_types, index_column_pandas_labels=self.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=self.index_column_snowflake_quoted_identifiers, + index_column_types=self.cached_index_column_snowpark_pandas_types, ) def ensure_row_count_column(self) -> "InternalFrame": @@ -873,6 +901,8 @@ def ensure_row_count_column(self) -> "InternalFrame": data_column_pandas_index_names=self.data_column_pandas_index_names, index_column_pandas_labels=self.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=self.index_column_snowflake_quoted_identifiers, + data_column_types=self.cached_data_column_snowpark_pandas_types, + index_column_types=self.cached_index_column_snowpark_pandas_types, ) def persist_to_temporary_table(self) -> "InternalFrame": @@ -886,9 +916,11 @@ def persist_to_temporary_table(self) -> "InternalFrame": ordered_dataframe=cache_result(self.ordered_dataframe), data_column_pandas_labels=self.data_column_pandas_labels, data_column_snowflake_quoted_identifiers=self.data_column_snowflake_quoted_identifiers, + data_column_types=self.cached_data_column_snowpark_pandas_types, data_column_pandas_index_names=self.data_column_pandas_index_names, index_column_pandas_labels=self.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=self.index_column_snowflake_quoted_identifiers, + index_column_types=self.cached_index_column_snowpark_pandas_types, ) def append_column( @@ -943,6 +975,8 @@ def append_column( data_column_pandas_index_names=self.data_column_pandas_index_names, index_column_pandas_labels=self.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=self.index_column_snowflake_quoted_identifiers, + data_column_types=self.cached_data_column_snowpark_pandas_types + [None], + index_column_types=self.cached_index_column_snowpark_pandas_types, ) def project_columns( @@ -981,6 +1015,8 @@ def project_columns( data_column_pandas_index_names=self.data_column_pandas_index_names, 
index_column_pandas_labels=self.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=self.index_column_snowflake_quoted_identifiers, + data_column_types=None, + index_column_types=self.cached_index_column_snowpark_pandas_types, ) def rename_snowflake_identifiers( @@ -1062,11 +1098,11 @@ def get_updated_identifiers(identifiers: list[str]) -> list[str]: self.data_column_snowflake_quoted_identifiers ), data_column_pandas_index_names=self.data_column_pandas_index_names, - data_column_types=self.cached_data_column_snowpark_pandas_types, index_column_pandas_labels=self.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=get_updated_identifiers( self.index_column_snowflake_quoted_identifiers ), + data_column_types=self.cached_data_column_snowpark_pandas_types, index_column_types=self.cached_index_column_snowpark_pandas_types, ) @@ -1080,7 +1116,7 @@ def update_snowflake_quoted_identifiers_with_expressions( This function takes a mapping from existing snowflake quoted identifiers to new Snowpark column expressions and points the existing quoted identifiers to the column expressions provided by the mapping. For optimization purposes, - existing expressions are kept as columns. This does not change pandas labels. + existing expressions are kept as columns. This does not change pandas labels or cached Snowpark pandas types. The process involves the following steps: @@ -1170,8 +1206,10 @@ def update_snowflake_quoted_identifiers_with_expressions( data_column_pandas_labels=self.data_column_pandas_labels, data_column_snowflake_quoted_identifiers=new_data_column_snowflake_quoted_identifiers, data_column_pandas_index_names=self.data_column_pandas_index_names, + data_column_types=self.cached_data_column_snowpark_pandas_types, index_column_pandas_labels=self.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=new_index_column_snowflake_quoted_identifiers, + index_column_types=self.cached_index_column_snowpark_pandas_types, ), existing_id_to_new_id_mapping, ) @@ -1239,6 +1277,8 @@ def select_active_columns(self) -> "InternalFrame": data_column_pandas_labels=self.data_column_pandas_labels, data_column_snowflake_quoted_identifiers=self.data_column_snowflake_quoted_identifiers, data_column_pandas_index_names=self.data_column_pandas_index_names, + data_column_types=self.cached_data_column_snowpark_pandas_types, + index_column_types=self.cached_index_column_snowpark_pandas_types, ) def strip_duplicates( @@ -1305,6 +1345,8 @@ def strip_duplicates( data_column_pandas_index_names=frame.data_column_pandas_index_names, index_column_pandas_labels=frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=frame.index_column_snowflake_quoted_identifiers, + data_column_types=frame.cached_data_column_snowpark_pandas_types, + index_column_types=frame.cached_index_column_snowpark_pandas_types, ) def filter( @@ -1325,6 +1367,8 @@ def filter( data_column_pandas_index_names=self.data_column_pandas_index_names, index_column_pandas_labels=self.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=self.index_column_snowflake_quoted_identifiers, + data_column_types=self.cached_data_column_snowpark_pandas_types, + index_column_types=self.cached_index_column_snowpark_pandas_types, ) def normalize_snowflake_quoted_identifiers_with_pandas_label( diff --git a/src/snowflake/snowpark/modin/plugin/_internal/generator_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/generator_utils.py index bea4387864d..3cc37386108 --- 
a/src/snowflake/snowpark/modin/plugin/_internal/generator_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/generator_utils.py @@ -101,6 +101,8 @@ def _create_qc_from_snowpark_dataframe( index_column_snowflake_quoted_identifiers=[ odf.row_position_snowflake_quoted_identifier ], + data_column_types=None, + index_column_types=None, ) ) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/groupby_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/groupby_utils.py index ab71f314038..09572a16d87 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/groupby_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/groupby_utils.py @@ -386,6 +386,8 @@ def get_frame_with_groupby_columns_as_index( data_column_pandas_labels=internal_frame.data_column_pandas_labels, data_column_snowflake_quoted_identifiers=internal_frame.data_column_snowflake_quoted_identifiers, data_column_pandas_index_names=internal_frame.data_column_pandas_index_names, + data_column_types=internal_frame.cached_data_column_snowpark_pandas_types, + index_column_types=internal_frame.cached_index_column_snowpark_pandas_types, ) ) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/indexing_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/indexing_utils.py index 041c5069ccd..f0e33d0b8b8 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/indexing_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/indexing_utils.py @@ -406,6 +406,8 @@ def get_frame_by_row_pos_frame( data_column_pandas_index_names=key.data_column_pandas_index_names, index_column_snowflake_quoted_identifiers=key.index_column_snowflake_quoted_identifiers, index_column_pandas_labels=key.index_column_pandas_labels, + data_column_types=key.cached_data_column_snowpark_pandas_types[1:], + index_column_types=key.cached_index_column_snowpark_pandas_types, ) return _get_frame_by_row_pos_int_frame(internal_frame, key) @@ -453,6 +455,8 @@ def _get_frame_by_row_pos_boolean_frame( index_column_snowflake_quoted_identifiers=result_column_mapper.map_left_quoted_identifiers( internal_frame.index_column_snowflake_quoted_identifiers ), + data_column_types=internal_frame.cached_data_column_snowpark_pandas_types, + index_column_types=internal_frame.cached_index_column_snowpark_pandas_types, ) @@ -500,6 +504,8 @@ def _get_frame_by_row_pos_int_frame( index_column_snowflake_quoted_identifiers=result_column_mapper.map_right_quoted_identifiers( internal_frame.index_column_snowflake_quoted_identifiers ), + data_column_types=internal_frame.cached_data_column_snowpark_pandas_types, + index_column_types=internal_frame.cached_index_column_snowpark_pandas_types, ) @@ -553,6 +559,8 @@ def _get_adjusted_key_frame_by_row_pos_int_frame( count_ordered_dataframe.row_position_snowflake_quoted_identifier ], data_column_pandas_index_names=[None], + data_column_types=[None], + index_column_types=[None], ) # cross join the count with the key to append the count column with the key frame. 
For example: if the @@ -682,6 +690,8 @@ def make_positive(val: int) -> Column: data_column_snowflake_quoted_identifiers=internal_frame.data_column_snowflake_quoted_identifiers, index_column_pandas_labels=internal_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=internal_frame.index_column_snowflake_quoted_identifiers, + data_column_types=internal_frame.cached_data_column_snowpark_pandas_types, + index_column_types=internal_frame.cached_index_column_snowpark_pandas_types, ) @@ -1013,6 +1023,8 @@ def get_frame_by_col_label( data_column_pandas_index_names=new_data_column_pandas_index_names, index_column_pandas_labels=result.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=result.index_column_snowflake_quoted_identifiers, + data_column_types=result.cached_data_column_snowpark_pandas_types, + index_column_types=result.cached_index_column_snowpark_pandas_types, ) return result @@ -1129,10 +1141,19 @@ def get_frame_by_col_pos( selected_columns: list[ColumnOrName] = [] # the snowflake quoted identifiers for the selected Snowpark columns selected_columns_quoted_identifiers: list[str] = [] + selected_columns_types = [] + for col_index in valid_indices: snowflake_quoted_identifier = frame_data_column_quoted_identifiers_list[ col_index ] + + selected_columns_types.append( + internal_frame.snowflake_quoted_identifier_to_snowpark_pandas_type[ + snowflake_quoted_identifier + ] + ) + pandas_label = frame_data_column_pandas_labels_list[col_index] if snowflake_quoted_identifier in selected_columns_quoted_identifiers: # if the current column has already been selected, duplicate the column with @@ -1161,6 +1182,8 @@ def get_frame_by_col_pos( data_column_snowflake_quoted_identifiers=selected_columns_quoted_identifiers, index_column_pandas_labels=internal_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=internal_frame.index_column_snowflake_quoted_identifiers, + data_column_types=selected_columns_types, + index_column_types=internal_frame.cached_index_column_snowpark_pandas_types, ) @@ -1275,6 +1298,9 @@ def _get_frame_by_row_multiindex_label_tuple( new_index_column_snowflake_quoted_identifiers = ( filtered_frame.index_column_snowflake_quoted_identifiers[levels_to_drop:] ) + new_index_types = filtered_frame.cached_index_column_snowpark_pandas_types[ + levels_to_drop: + ] return InternalFrame.create( ordered_dataframe=filtered_frame.ordered_dataframe, @@ -1283,6 +1309,8 @@ def _get_frame_by_row_multiindex_label_tuple( data_column_pandas_index_names=filtered_frame.data_column_pandas_index_names, index_column_pandas_labels=new_index_column_pandas_labels, index_column_snowflake_quoted_identifiers=new_index_column_snowflake_quoted_identifiers, + data_column_types=filtered_frame.cached_data_column_snowpark_pandas_types, + index_column_types=new_index_types, ) @@ -1541,6 +1569,8 @@ def generate_bound_column( data_column_snowflake_quoted_identifiers=internal_frame.data_column_snowflake_quoted_identifiers, index_column_pandas_labels=internal_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=internal_frame.index_column_snowflake_quoted_identifiers, + data_column_types=internal_frame.cached_data_column_snowpark_pandas_types, + index_column_types=internal_frame.cached_index_column_snowpark_pandas_types, ) @@ -1590,6 +1620,8 @@ def _get_frame_by_row_label_boolean_frame( index_column_snowflake_quoted_identifiers=result_column_mapper.map_left_quoted_identifiers( internal_frame.index_column_snowflake_quoted_identifiers ), + 
data_column_types=internal_frame.cached_data_column_snowpark_pandas_types, + index_column_types=internal_frame.cached_index_column_snowpark_pandas_types, ) @@ -1664,6 +1696,8 @@ def _get_frame_by_row_label_non_boolean_frame( index_column_snowflake_quoted_identifiers=result_column_mapper.map_right_quoted_identifiers( internal_frame.index_column_snowflake_quoted_identifiers ), + data_column_types=internal_frame.cached_data_column_snowpark_pandas_types, + index_column_types=internal_frame.cached_index_column_snowpark_pandas_types, ) @@ -1724,6 +1758,8 @@ def _get_frame_by_row_series_bool( data_column_snowflake_quoted_identifiers=key.data_column_snowflake_quoted_identifiers, index_column_pandas_labels=key.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=key.index_column_snowflake_quoted_identifiers, + data_column_types=key.cached_data_column_snowpark_pandas_types, + index_column_types=key.cached_index_column_snowpark_pandas_types, ) joined_frame, result_column_mapper = join( @@ -1745,6 +1781,8 @@ def _get_frame_by_row_series_bool( index_column_snowflake_quoted_identifiers=result_column_mapper.map_right_quoted_identifiers( internal_frame.index_column_snowflake_quoted_identifiers ), + data_column_types=internal_frame.cached_data_column_snowpark_pandas_types, + index_column_types=internal_frame.cached_index_column_snowpark_pandas_types, ) @@ -2759,6 +2797,8 @@ def set_frame_2d_positional( data_column_snowflake_quoted_identifiers=new_data_column_snowflake_quoted_identifiers, index_column_pandas_labels=frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=frame.index_column_snowflake_quoted_identifiers, + data_column_types=frame.cached_data_column_snowpark_pandas_types, + index_column_types=frame.cached_index_column_snowpark_pandas_types, ) @@ -2909,6 +2949,8 @@ def get_kv_frame_from_index_and_item_frames( + new_item_data_column_snowflake_identifiers, index_column_pandas_labels=kv_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=kv_frame.index_column_snowflake_quoted_identifiers, + data_column_types=kv_frame.cached_data_column_snowpark_pandas_types, + index_column_types=kv_frame.cached_index_column_snowpark_pandas_types, ) return new_kv_frame @@ -2979,6 +3021,8 @@ def get_item_series_as_single_row_frame( item_series_snowflake_quoted_identifiers: list[str] = [] item_series_column_exprs: list[Column] = [] + item_series_data_column_types = [] + for row_position, pandas_label in enumerate(item_series_pandas_labels): new_snowflake_quoted_identifier = ( item_frame.ordered_dataframe.generate_snowflake_quoted_identifiers( @@ -2999,6 +3043,9 @@ def get_item_series_as_single_row_frame( ) item_series_snowflake_quoted_identifiers.append(new_snowflake_quoted_identifier) item_series_column_exprs.append(new_column_expr) + item_series_data_column_types.append( + item.cached_data_column_snowpark_pandas_types[0] + ) item_ordered_dataframe = append_columns( item_frame.ordered_dataframe, @@ -3022,6 +3069,8 @@ def get_item_series_as_single_row_frame( index_column_snowflake_quoted_identifiers=item.index_column_snowflake_quoted_identifiers[ :1 ], + data_column_types=item_series_data_column_types, + index_column_types=item.cached_index_column_snowpark_pandas_types[:1], ) return item @@ -3114,6 +3163,16 @@ def get_row_position_index_from_bool_indexer(index: InternalFrame) -> InternalFr index_column_snowflake_quoted_identifiers=[ index_column_snowflake_quoted_identifier ], + data_column_types=[ + 
index.snowflake_quoted_identifier_to_snowpark_pandas_type.get( + data_column_snowflake_quoted_identifier, None + ) + ], + index_column_types=[ + index.snowflake_quoted_identifier_to_snowpark_pandas_type.get( + index_column_snowflake_quoted_identifier, None + ) + ], ) return index diff --git a/src/snowflake/snowpark/modin/plugin/_internal/isin_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/isin_utils.py index ab2a56cb47c..26d50a8d53c 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/isin_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/isin_utils.py @@ -232,6 +232,8 @@ def compute_isin_with_dataframe( + new_identifiers, index_column_pandas_labels=frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=frame.index_column_snowflake_quoted_identifiers, + data_column_types=None, + index_column_types=None, ) # local import to avoid circular import diff --git a/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py index 10d197f2ce2..2e7c6c91a87 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py @@ -370,6 +370,8 @@ def _create_internal_frame_with_join_or_align_result( index_column_pandas_labels=index_column_pandas_labels, index_column_snowflake_quoted_identifiers=index_column_snowflake_quoted_identifiers, data_column_pandas_index_names=data_column_pandas_index_names, + data_column_types=None, + index_column_types=None, ) result_column_mapper = JoinOrAlignResultColumnMapper( left_quoted_identifiers_map, @@ -886,6 +888,8 @@ def _reorder_index_columns( data_column_pandas_labels=frame.data_column_pandas_labels, data_column_snowflake_quoted_identifiers=frame.data_column_snowflake_quoted_identifiers, data_column_pandas_index_names=frame.data_column_pandas_index_names, + data_column_types=None, + index_column_types=None, ) else: return frame diff --git a/src/snowflake/snowpark/modin/plugin/_internal/pivot_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/pivot_utils.py index 36890be82e9..861fd7db1fc 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/pivot_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/pivot_utils.py @@ -257,6 +257,8 @@ def pivot_helper( data_column_snowflake_quoted_identifiers=[], index_column_pandas_labels=index, index_column_snowflake_quoted_identifiers=groupby_snowflake_quoted_identifiers, + data_column_types=None, + index_column_types=None, ) data_column_pandas_labels: list[Hashable] = [] data_column_snowflake_quoted_identifiers: list[str] = [] @@ -482,6 +484,8 @@ def pivot_helper( data_column_snowflake_quoted_identifiers=data_column_snowflake_quoted_identifiers, index_column_pandas_labels=index, index_column_snowflake_quoted_identifiers=index_column_snowflake_quoted_identifiers, + data_column_types=None, + index_column_types=None, ) @@ -1285,6 +1289,8 @@ def expand_pivot_result_with_pivot_table_margins_no_groupby_columns( data_column_snowflake_quoted_identifiers=margins_frame.data_column_snowflake_quoted_identifiers, index_column_pandas_labels=margins_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=margins_frame.index_column_snowflake_quoted_identifiers, + data_column_types=margins_frame.cached_data_column_snowpark_pandas_types, + index_column_types=margins_frame.cached_index_column_snowpark_pandas_types, ) # Need to create a QueryCompiler for the margins frame, but SnowflakeQueryCompiler is not present in this scope @@ 
-1691,6 +1697,8 @@ def expand_pivot_result_with_pivot_table_margins( data_column_pandas_index_names=pivoted_frame.data_column_pandas_index_names, index_column_pandas_labels=pivoted_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=pivoted_frame.index_column_snowflake_quoted_identifiers, + data_column_types=None, + index_column_types=None, ) from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import ( @@ -1758,6 +1766,8 @@ def expand_pivot_result_with_pivot_table_margins( index_column_snowflake_quoted_identifiers=margin_row_df_identifiers[ 0 : len(groupby_snowflake_quoted_identifiers) ], + data_column_types=None, + index_column_types=None, ) single_row_qc = SnowflakeQueryCompiler(margin_row_frame) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py index d09444cbc32..4e4da1bb499 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py @@ -511,6 +511,8 @@ def get_expected_resample_bins_frame( index_column_pandas_labels=[RESAMPLE_INDEX_LABEL], index_column_snowflake_quoted_identifiers=index_column_snowflake_quoted_identifiers, data_column_pandas_index_names=[None], + data_column_types=None, + index_column_types=None, ) @@ -615,6 +617,8 @@ def fill_missing_resample_bins_for_frame( index_column_pandas_labels=frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=joined_frame.index_column_snowflake_quoted_identifiers, data_column_pandas_index_names=frame.data_column_pandas_index_names, + data_column_types=frame.cached_data_column_snowpark_pandas_types, + index_column_types=frame.cached_index_column_snowpark_pandas_types, ) @@ -780,4 +784,6 @@ def perform_asof_join_on_frame( left_timecol_snowflake_quoted_identifier ], data_column_pandas_index_names=referenced_frame.data_column_pandas_index_names, + data_column_types=referenced_frame.cached_data_column_snowpark_pandas_types, + index_column_types=referenced_frame.cached_index_column_snowpark_pandas_types, ) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/transpose_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/transpose_utils.py index e2fa21fab07..b7204edb619 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/transpose_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/transpose_utils.py @@ -284,6 +284,8 @@ def clean_up_transpose_result_index_and_labels( data_column_snowflake_quoted_identifiers=new_data_column_snowflake_quoted_identifiers, index_column_pandas_labels=new_index_column_pandas_labels, index_column_snowflake_quoted_identifiers=new_index_column_snowflake_quoted_identifiers, + data_column_types=None, + index_column_types=None, ) # Rename the data column snowflake quoted identifiers to be closer to pandas labels, normalizing names diff --git a/src/snowflake/snowpark/modin/plugin/_internal/unpivot_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/unpivot_utils.py index 3c1bd87bdef..0fa497b8abf 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/unpivot_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/unpivot_utils.py @@ -660,6 +660,8 @@ def clean_up_unpivot( data_column_snowflake_quoted_identifiers=final_snowflake_qouted_identfiers, index_column_pandas_labels=index_column_pandas_names, index_column_snowflake_quoted_identifiers=index_column_quoted_names, + data_column_types=None, + index_column_types=None, ) # Rename the data column 
snowflake quoted identifiers to be closer to pandas labels, normalizing names @@ -879,6 +881,8 @@ def _simple_unpivot( index_column_snowflake_quoted_identifiers=[ ordered_dataframe.row_position_snowflake_quoted_identifier ], + data_column_types=None, + index_column_types=None, ) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/utils.py b/src/snowflake/snowpark/modin/plugin/_internal/utils.py index 0b87119e8f5..2c641a91410 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/utils.py @@ -1765,6 +1765,7 @@ def create_frame_with_data_columns( new_frame_data_column_pandas_labels = [] new_frame_data_column_snowflake_quoted_identifier = [] + new_frame_data_column_types = [] data_column_label_to_snowflake_quoted_identifier = { data_column_pandas_label: data_column_snowflake_quoted_identifier @@ -1786,6 +1787,11 @@ def create_frame_with_data_columns( new_frame_data_column_snowflake_quoted_identifier.append( snowflake_quoted_identifier ) + new_frame_data_column_types.append( + frame.snowflake_quoted_identifier_to_snowpark_pandas_type.get( + snowflake_quoted_identifier, None + ) + ) return InternalFrame.create( ordered_dataframe=frame.ordered_dataframe, @@ -1794,6 +1800,8 @@ def create_frame_with_data_columns( data_column_pandas_index_names=frame.data_column_pandas_index_names, index_column_pandas_labels=frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=frame.index_column_snowflake_quoted_identifiers, + data_column_types=new_frame_data_column_types, + index_column_types=frame.cached_index_column_snowpark_pandas_types, ) diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index 30ad5b2713b..108b594faf6 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -3,6 +3,7 @@ # import calendar import functools +import inspect import itertools import json import logging @@ -259,6 +260,9 @@ rule_to_snowflake_width_and_slice_unit, validate_resample_supported_by_snowflake, ) +from snowflake.snowpark.modin.plugin._internal.snowpark_pandas_types import ( + TimedeltaType, +) from snowflake.snowpark.modin.plugin._internal.timestamp_utils import ( VALID_TO_DATETIME_DF_KEYS, DateTimeOrigin, @@ -385,6 +389,60 @@ def __init__(self, frame: InternalFrame) -> None: # Copying and modifying self.snowpark_pandas_api_calls is taken care of in telemetry decorators self.snowpark_pandas_api_calls: list = [] + def _raise_not_implemented_error_for_timedelta(self) -> None: + """Raise NotImplementedError for SnowflakeQueryCompiler methods that do not support timedelta yet.""" + for ( + val + ) in ( + self._modin_frame.snowflake_quoted_identifier_to_snowpark_pandas_type.values() + ): + method = "Unknown method" + if isinstance(val, TimedeltaType): + try: + method = inspect.currentframe().f_back.f_back.f_code.co_name # type: ignore[union-attr] + except Exception: + pass + finally: + ErrorMessage.not_implemented_for_timedelta(method) + + def snowpark_pandas_type_immutable_check(func: Callable) -> Any: + """Decorator for SnowflakeQueryCompiler methods that return a new SnowflakeQueryCompiler. + It verifies that the cached Snowpark pandas types are not changed. 
+ """ + + def check_type(input: List, output: List) -> None: + assert len(input) == len( + output + ), "self frame and output frame have different number of columns" + + for lt, rt in zip(input, output): + assert ( + lt == rt + ), f"one column's Snowpark pandas type has been changed from {lt} to {rt}" + + @functools.wraps(func) + def wrap(*args, **kwargs): # type: ignore + self_qc = args[0] + output_qc = func(*args, **kwargs) + assert isinstance(self_qc, SnowflakeQueryCompiler) and isinstance( + output_qc, SnowflakeQueryCompiler + ), ( + "immutable_snowpark_pandas_type_check only works with SnowflakeQueryCompiler member methods with " + "SnowflakeQueryCompiler as the return result" + ) + check_type( + self_qc._modin_frame.cached_index_column_snowpark_pandas_types, + output_qc._modin_frame.cached_index_column_snowpark_pandas_types, + ) + check_type( + self_qc._modin_frame.cached_data_column_snowpark_pandas_types, + output_qc._modin_frame.cached_data_column_snowpark_pandas_types, + ) + + return output_qc + + return wrap + @property def dtypes(self) -> native_pd.Series: """ @@ -641,6 +699,7 @@ def from_date_range( dt_series = dt_series[dt_series != end] return dt_series._query_compiler + @snowpark_pandas_type_immutable_check def copy(self) -> "SnowflakeQueryCompiler": """ Make a copy of this object. @@ -851,6 +910,8 @@ def find_snowflake_quoted_identifier(pandas_columns: list[str]) -> list[str]: data_column_snowflake_quoted_identifiers=data_column_snowflake_quoted_identifiers, index_column_pandas_labels=index_column_pandas_labels, index_column_snowflake_quoted_identifiers=index_column_snowflake_quoted_identifiers, + data_column_types=None, # from_snowflake won't provide any client side type info + index_column_types=None, ) ) @@ -1201,6 +1262,8 @@ def to_csv_with_snowflake(self, **kwargs: Any) -> None: Args: **kwargs: to_csv arguments. """ + self._raise_not_implemented_error_for_timedelta() + # Raise not implemented error for unsupported parameters. unsupported_params = [ "float_format", @@ -1256,6 +1319,8 @@ def to_snowflake( index_label: Optional[IndexLabel] = None, table_type: Literal["", "temp", "temporary", "transient"] = "", ) -> None: + self._raise_not_implemented_error_for_timedelta() + if if_exists not in ("fail", "replace", "append"): # Same error message as native pandas. raise ValueError(f"'{if_exists}' is not valid for if_exists") @@ -1299,11 +1364,13 @@ def to_snowpark( For details, please see comment in _to_snowpark_dataframe_of_pandas_dataframe. """ + self._raise_not_implemented_error_for_timedelta() return self._to_snowpark_dataframe_from_snowpark_pandas_dataframe( index, index_label ) + @snowpark_pandas_type_immutable_check def cache_result(self) -> "SnowflakeQueryCompiler": """ Returns a materialized view of this QueryCompiler. 
@@ -1321,6 +1388,7 @@ def columns(self) -> native_pd.Index: # TODO SNOW-837664: add more tests for df.columns return self._modin_frame.data_columns_index + @snowpark_pandas_type_immutable_check def set_columns(self, new_pandas_labels: Axes) -> "SnowflakeQueryCompiler": """ Set pandas column labels with the new column labels @@ -1367,9 +1435,9 @@ def set_columns(self, new_pandas_labels: Axes) -> "SnowflakeQueryCompiler": data_column_pandas_labels=new_pandas_labels.tolist(), data_column_pandas_index_names=new_pandas_labels.names, data_column_snowflake_quoted_identifiers=new_data_column_snowflake_quoted_identifiers, - data_column_types=renamed_frame.cached_data_column_snowpark_pandas_types, index_column_pandas_labels=renamed_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=renamed_frame.index_column_snowflake_quoted_identifiers, + data_column_types=renamed_frame.cached_data_column_snowpark_pandas_types, index_column_types=renamed_frame.cached_index_column_snowpark_pandas_types, ) return SnowflakeQueryCompiler(new_internal_frame) @@ -1678,6 +1746,8 @@ def set_index_from_series( Returns: The new SnowflakeQueryCompiler after the set_index operation """ + self._raise_not_implemented_error_for_timedelta() + assert ( len(key._modin_frame.data_column_pandas_labels) == 1 ), "need to be a series" @@ -1724,6 +1794,8 @@ def set_index_from_series( ), index_column_pandas_labels=new_index_labels, index_column_snowflake_quoted_identifiers=new_index_ids, + data_column_types=None, + index_column_types=None, ) return SnowflakeQueryCompiler(new_internal_frame) @@ -1797,6 +1869,8 @@ def _binary_op_list_like_rhs_axis_0( """ from snowflake.snowpark.modin.pandas.series import Series + self._raise_not_implemented_error_for_timedelta() + # Step 1: Convert other to a Series and join on the row position with self. other_qc = Series(other)._query_compiler self_frame = self._modin_frame.ensure_row_position_column() @@ -1843,6 +1917,8 @@ def _binary_op_list_like_rhs_axis_0( data_column_pandas_index_names=new_frame.data_column_pandas_index_names, index_column_pandas_labels=new_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=new_frame.index_column_snowflake_quoted_identifiers, + data_column_types=None, + index_column_types=None, ) return SnowflakeQueryCompiler(new_frame) @@ -1940,6 +2016,8 @@ def binary_op( from snowflake.snowpark.modin.pandas.series import Series from snowflake.snowpark.modin.pandas.utils import is_scalar + self._raise_not_implemented_error_for_timedelta() + # fail explicitly for unsupported scenarios if level is not None: # TODO SNOW-862668: binary operations with level @@ -2213,6 +2291,8 @@ def _add_columns_for_monotonicity_checks( and the Snowflake quoted identifiers for the monotonically increasing and monotonically decreasing columns (in that order). 
""" + self._raise_not_implemented_error_for_timedelta() + modin_frame = self._modin_frame modin_frame = modin_frame.ensure_row_position_column() row_position_column = modin_frame.row_position_snowflake_quoted_identifier @@ -2256,6 +2336,8 @@ def _add_columns_for_monotonicity_checks( data_column_snowflake_quoted_identifiers=data_column_snowflake_quoted_identifiers, index_column_pandas_labels=modin_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=modin_frame.index_column_snowflake_quoted_identifiers, + data_column_types=None, + index_column_types=None, ) return ( SnowflakeQueryCompiler(modin_frame), @@ -2289,6 +2371,8 @@ def _reindex_axis_0( SnowflakeQueryCompiler QueryCompiler with aligned axis. """ + self._raise_not_implemented_error_for_timedelta() + new_index_qc = pd.Series(labels)._query_compiler new_index_modin_frame = new_index_qc._modin_frame modin_frame = self._modin_frame @@ -2382,6 +2466,8 @@ def _reindex_axis_0( index_column_snowflake_quoted_identifiers=result_frame_column_mapper.map_left_quoted_identifiers( new_index_modin_frame.data_column_snowflake_quoted_identifiers ), + data_column_types=None, + index_column_types=None, ) new_qc = SnowflakeQueryCompiler(new_modin_frame) if method or fill_value is not np.nan: @@ -2443,6 +2529,8 @@ def _reindex_axis_0( data_column_pandas_index_names=new_qc._modin_frame.data_column_pandas_index_names, index_column_pandas_labels=new_qc._modin_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=new_qc._modin_frame.index_column_snowflake_quoted_identifiers, + data_column_types=None, + index_column_types=None, ) ) if fill_value is not np.nan: @@ -2467,6 +2555,8 @@ def _reindex_axis_0( data_column_pandas_index_names=new_qc._modin_frame.data_column_pandas_index_names, index_column_pandas_labels=new_qc._modin_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=new_qc._modin_frame.index_column_snowflake_quoted_identifiers, + data_column_types=None, + index_column_types=None, ) ) if is_index: @@ -2495,6 +2585,8 @@ def _reindex_axis_0( data_column_pandas_index_names=modin_frame.data_column_pandas_index_names, index_column_pandas_labels=modin_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=modin_frame.index_column_snowflake_quoted_identifiers, + data_column_types=None, + index_column_types=None, ) ) materialized_frame = new_qc._modin_frame.ordered_dataframe.select( @@ -2559,6 +2651,8 @@ def _reindex_axis_1( SnowflakeQueryCompiler QueryCompiler with aligned axis. 
""" + self._raise_not_implemented_error_for_timedelta() + method = kwargs.get("method", None) level = kwargs.get("level", None) limit = kwargs.get("limit", None) @@ -2593,6 +2687,8 @@ def _reindex_axis_1( data_column_pandas_index_names=self._modin_frame.data_column_pandas_index_names, index_column_pandas_labels=self._modin_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=self._modin_frame.index_column_snowflake_quoted_identifiers, + data_column_types=None, + index_column_types=None, ) new_qc = SnowflakeQueryCompiler(new_modin_frame) ordered_columns = sorted(data_column_pandas_labels) @@ -2727,8 +2823,10 @@ def reset_index( ( index_column_pandas_labels_moved, index_column_snowflake_quoted_identifiers_moved, + index_column_types_moved, index_column_pandas_labels_remained, index_column_snowflake_quoted_identifiers_remained, + index_column_types_remained, ) = self._modin_frame.get_snowflake_identifiers_and_pandas_labels_from_levels( levels_to_be_reset ) @@ -2755,6 +2853,7 @@ def reset_index( index_column_snowflake_quoted_identifiers_remained = [ index_column_snowflake_quoted_identifier ] + index_column_types_remained = [None] # Do not drop existing index columns and move them to data columns. if not drop: @@ -2829,19 +2928,30 @@ def reset_index( index_column_snowflake_quoted_identifiers_moved + self._modin_frame.data_column_snowflake_quoted_identifiers ) + + data_column_types = ( + index_column_types_moved + + self._modin_frame.cached_data_column_snowpark_pandas_types + ) + else: data_column_pandas_labels = self._modin_frame.data_column_pandas_labels data_column_snowflake_quoted_identifiers = ( self._modin_frame.data_column_snowflake_quoted_identifiers ) + data_column_types = ( + self._modin_frame.cached_data_column_snowpark_pandas_types + ) internal_frame = InternalFrame.create( ordered_dataframe=ordered_dataframe, data_column_pandas_labels=data_column_pandas_labels, data_column_snowflake_quoted_identifiers=data_column_snowflake_quoted_identifiers, data_column_pandas_index_names=self._modin_frame.data_column_pandas_index_names, + data_column_types=data_column_types, index_column_pandas_labels=index_column_pandas_labels_remained, index_column_snowflake_quoted_identifiers=index_column_snowflake_quoted_identifiers_remained, + index_column_types=index_column_types_remained, ) return SnowflakeQueryCompiler(internal_frame) @@ -3078,6 +3188,12 @@ def sort_rows_by_column_values( data_column_snowflake_quoted_identifiers.append( internal_frame.row_position_snowflake_quoted_identifier ) + data_column_types = [ + internal_frame.snowflake_quoted_identifier_to_snowpark_pandas_type.get( + id, None + ) + for id in data_column_snowflake_quoted_identifiers + ] sorted_frame = InternalFrame.create( ordered_dataframe=ordered_dataframe, data_column_pandas_labels=data_column_pandas_labels, @@ -3085,6 +3201,8 @@ def sort_rows_by_column_values( data_column_pandas_index_names=internal_frame.data_column_pandas_index_names, index_column_pandas_labels=internal_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=internal_frame.index_column_snowflake_quoted_identifiers, + data_column_types=data_column_types, + index_column_types=internal_frame.cached_index_column_snowpark_pandas_types, ) sorted_qc = SnowflakeQueryCompiler(sorted_frame) @@ -3112,6 +3230,8 @@ def validate_groupby( KeyError if a hashable label in by (groupby items) can not be found in the current dataframe ValueError if more than one column can be found for the groupby item """ + 
self._raise_not_implemented_error_for_timedelta() + validate_groupby_columns(self, by, axis, level) def groupby_ngroups( @@ -3120,6 +3240,8 @@ def groupby_ngroups( axis: int, groupby_kwargs: dict[str, Any], ) -> int: + self._raise_not_implemented_error_for_timedelta() + level = groupby_kwargs.get("level", None) dropna = groupby_kwargs.get("dropna", True) @@ -3205,6 +3327,8 @@ def groupby_agg( Returns: SnowflakeQueryCompiler: with a newly constructed internal dataframe """ + self._raise_not_implemented_error_for_timedelta() + level = groupby_kwargs.get("level", None) if agg_func in ["head", "tail"]: @@ -3443,6 +3567,8 @@ def convert_func_to_agg_func_info( data_column_snowflake_quoted_identifiers=new_data_column_quoted_identifier, index_column_pandas_labels=new_index_column_pandas_labels, index_column_snowflake_quoted_identifiers=new_index_column_quoted_identifiers, + data_column_types=None, + index_column_types=None, ) ) @@ -3482,6 +3608,8 @@ def groupby_apply( ------- A query compiler with the result. """ + self._raise_not_implemented_error_for_timedelta() + level = groupby_kwargs.get("level", None) if not check_is_groupby_supported_by_snowflake(by, level, axis): ErrorMessage.not_implemented( @@ -3946,6 +4074,8 @@ def groupby_first( Returns: SnowflakeQueryCompiler: The result of groupby_first() """ + self._raise_not_implemented_error_for_timedelta() + return self._groupby_first_last( "first", by, axis, groupby_kwargs, agg_args, agg_kwargs, drop, **kwargs ) @@ -3979,6 +4109,8 @@ def groupby_last( Returns: SnowflakeQueryCompiler: The result of groupby_last() """ + self._raise_not_implemented_error_for_timedelta() + return self._groupby_first_last( "last", by, axis, groupby_kwargs, agg_args, agg_kwargs, drop, **kwargs ) @@ -4053,6 +4185,8 @@ def groupby_rank( 5 1 6 2 """ + self._raise_not_implemented_error_for_timedelta() + level = groupby_kwargs.get("level", None) dropna = groupby_kwargs.get("dropna", True) @@ -4225,6 +4359,8 @@ def groupby_shift( Returns: SnowflakeQueryCompiler: with a newly constructed internal dataframe """ + self._raise_not_implemented_error_for_timedelta() + # TODO: handle cases where the fill_value has a different type from # the column. SNOW-990325 deals with fillna that has a similar problem. @@ -4339,6 +4475,8 @@ def groupby_shift( data_column_snowflake_quoted_identifiers=snowflake_quoted_identifiers, index_column_pandas_labels=self._modin_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=self._modin_frame.index_column_snowflake_quoted_identifiers, + data_column_types=None, + index_column_types=None, ) ) @@ -4370,6 +4508,8 @@ def groupby_get_group( Returns: SnowflakeQueryCompiler: The result of groupby_get_group(). """ + self._raise_not_implemented_error_for_timedelta() + level = groupby_kwargs.get("level", None) is_supported = check_is_groupby_supported_by_snowflake(by, level, axis) if not is_supported: # pragma: no cover @@ -4437,6 +4577,8 @@ def groupby_size( Returns: SnowflakeQueryCompiler: The result of groupby_size() """ + self._raise_not_implemented_error_for_timedelta() + level = groupby_kwargs.get("level", None) is_supported = check_is_groupby_supported_by_snowflake(by, level, axis) if not is_supported: @@ -4517,6 +4659,8 @@ def groupby_groups( 4 5 2 4 5 0 8 9 0 8 """ + self._raise_not_implemented_error_for_timedelta() + original_index_names = self.get_index_names() frame = self._modin_frame index_data_columns = [] @@ -4624,6 +4768,8 @@ def groupby_indices( Returns: dict: a map from group keys to row labels. 
""" + self._raise_not_implemented_error_for_timedelta() + frame = self._modin_frame.ensure_row_position_column() return dict( # .indices aggregates row position numbers, so we add a row @@ -4669,6 +4815,8 @@ def groupby_cumcount( Returns: SnowflakeQueryCompiler: with a newly constructed internal dataframe """ + self._raise_not_implemented_error_for_timedelta() + return SnowflakeQueryCompiler( get_groupby_cumagg_frame_axis0( self, @@ -4704,6 +4852,8 @@ def groupby_cummax( Returns: SnowflakeQueryCompiler: with a newly constructed internal dataframe """ + self._raise_not_implemented_error_for_timedelta() + return SnowflakeQueryCompiler( get_groupby_cumagg_frame_axis0( self, @@ -4738,6 +4888,8 @@ def groupby_cummin( Returns: SnowflakeQueryCompiler: with a newly constructed internal dataframe """ + self._raise_not_implemented_error_for_timedelta() + return SnowflakeQueryCompiler( get_groupby_cumagg_frame_axis0( self, @@ -4769,6 +4921,8 @@ def groupby_cumsum( Returns: SnowflakeQueryCompiler: with a newly constructed internal dataframe """ + self._raise_not_implemented_error_for_timedelta() + return SnowflakeQueryCompiler( get_groupby_cumagg_frame_axis0( self, @@ -4791,6 +4945,8 @@ def groupby_nunique( drop: bool = False, **kwargs: Any, ) -> "SnowflakeQueryCompiler": + self._raise_not_implemented_error_for_timedelta() + # We have to override the Modin version of this function because our groupby frontend passes the # ignored numeric_only argument to this query compiler method, and BaseQueryCompiler # does not have **kwargs. @@ -4814,6 +4970,8 @@ def groupby_any( drop: bool = False, **kwargs: Any, ) -> "SnowflakeQueryCompiler": + self._raise_not_implemented_error_for_timedelta() + # We have to override the Modin version of this function because our groupby frontend passes the # ignored numeric_only argument to this query compiler method, and BaseQueryCompiler # does not have **kwargs. @@ -4837,6 +4995,8 @@ def groupby_all( drop: bool = False, **kwargs: Any, ) -> "SnowflakeQueryCompiler": + self._raise_not_implemented_error_for_timedelta() + # We have to override the Modin version of this function because our groupby frontend passes the # ignored numeric_only argument to this query compiler method, and BaseQueryCompiler # does not have **kwargs. @@ -4856,6 +5016,8 @@ def _get_dummies_helper( prefix: Hashable, prefix_sep: str, ) -> "SnowflakeQueryCompiler": + self._raise_not_implemented_error_for_timedelta() + dummy_column_name = random_name_for_temp_object(TempObjectType.COLUMN) # We need to add a column that will help us differentiate between identical # rows, so that we do not have aggregations happen on duplicate rows. @@ -5039,6 +5201,8 @@ def _get_dummies_helper( data_column_snowflake_quoted_identifiers=frame_data_column_snowflake_quoted_identifiers, index_column_pandas_labels=query_compiler._modin_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=query_compiler._modin_frame.index_column_snowflake_quoted_identifiers, + data_column_types=None, + index_column_types=None, ) if len(new_col_map) > 0: @@ -5085,6 +5249,8 @@ def get_dummies( 1 2 0 1 1 0 0 2 3 1 0 0 0 1 """ + self._raise_not_implemented_error_for_timedelta() + if dummy_na is True or drop_first is True or dtype is not None: ErrorMessage.not_implemented( "get_dummies with non-default dummy_na, drop_first, and dtype parameters" @@ -5153,6 +5319,8 @@ def agg( args: the arguments passed for the aggregation kwargs: keyword arguments passed for the aggregation function. 
""" + self._raise_not_implemented_error_for_timedelta() + numeric_only = kwargs.get("numeric_only", False) # Call fallback if the aggregation function passed in the arg is currently not supported # by snowflake engine. @@ -5374,6 +5542,8 @@ def generate_agg_qc( index_column_snowflake_quoted_identifiers=[ agg_name_col_quoted_identifier ], + data_column_types=None, + index_column_types=None, ) return SnowflakeQueryCompiler(single_agg_dataframe) @@ -5587,6 +5757,7 @@ def insert( Returns: A new SnowflakeQueryCompiler instance with new column. """ + self._raise_not_implemented_error_for_timedelta() if not isinstance(value, SnowflakeQueryCompiler): # Scalar value @@ -5686,6 +5857,8 @@ def move_last_element(arr: list, index: int) -> None: data_column_pandas_index_names=new_internal_frame.data_column_pandas_index_names, index_column_pandas_labels=new_internal_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=new_internal_frame.index_column_snowflake_quoted_identifiers, + data_column_types=None, + index_column_types=None, ) return SnowflakeQueryCompiler(new_internal_frame) @@ -5713,6 +5886,8 @@ def set_index_from_columns( Returns: A new QueryCompiler instance with updated index. """ + self._raise_not_implemented_error_for_timedelta() + index_column_pandas_labels = keys index_column_snowflake_quoted_identifiers = [] for ( @@ -5781,9 +5956,12 @@ def set_index_from_columns( data_column_pandas_index_names=self._modin_frame.data_column_pandas_index_names, data_column_pandas_labels=data_column_pandas_labels, data_column_snowflake_quoted_identifiers=data_column_snowflake_quoted_identifiers, + data_column_types=None, + index_column_types=None, ) return SnowflakeQueryCompiler(frame) + @snowpark_pandas_type_immutable_check def rename( self, *, @@ -5883,6 +6061,10 @@ def rename( data_column_pandas_index_names=internal_frame.data_column_pandas_index_names, index_column_pandas_labels=internal_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=internal_frame.index_column_snowflake_quoted_identifiers, + data_column_types=internal_frame.cached_data_column_snowpark_pandas_types[ + :-1 + ], + index_column_types=internal_frame.cached_index_column_snowpark_pandas_types, ) new_qc = SnowflakeQueryCompiler(internal_frame) @@ -6023,6 +6205,8 @@ def dataframe_to_datetime( data_column_pandas_index_names=self._modin_frame.data_column_pandas_index_names, index_column_pandas_labels=self._modin_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=self._modin_frame.index_column_snowflake_quoted_identifiers, + data_column_types=[None], + index_column_types=[None], ) ) @@ -6157,6 +6341,8 @@ def concat( NOTE: Original column level names are lost and result column index has only one level. """ + self._raise_not_implemented_error_for_timedelta() + if levels is not None: raise NotImplementedError( "Snowpark pandas doesn't support 'levels' argument in concat API" @@ -6346,6 +6532,8 @@ def cumsum( Returns: SnowflakeQueryCompiler instance with cumulative sum of Series or DataFrame. """ + self._raise_not_implemented_error_for_timedelta() + if axis == 1: ErrorMessage.not_implemented("cumsum with axis=1 is not supported yet") @@ -6373,6 +6561,8 @@ def cummin( Returns: SnowflakeQueryCompiler instance with cumulative min of Series or DataFrame. 
""" + self._raise_not_implemented_error_for_timedelta() + if axis == 1: ErrorMessage.not_implemented("cummin with axis=1 is not supported yet") @@ -6400,6 +6590,8 @@ def cummax( Returns: SnowflakeQueryCompiler instance with cumulative max of Series or DataFrame. """ + self._raise_not_implemented_error_for_timedelta() + if axis == 1: ErrorMessage.not_implemented("cummax with axis=1 is not supported yet") @@ -6439,6 +6631,8 @@ def melt( Notes: melt does not yet handle multiindex or ignore index """ + self._raise_not_implemented_error_for_timedelta() + if col_level is not None: raise NotImplementedError( "Snowpark Pandas doesn't support 'col_level' argument in melt API" @@ -6541,6 +6735,8 @@ def merge( Returns: SnowflakeQueryCompiler instance with merged result. """ + self._raise_not_implemented_error_for_timedelta() + if validate: ErrorMessage.not_implemented( "Snowpark pandas merge API doesn't yet support 'validate' parameter" @@ -6771,6 +6967,8 @@ def _apply_with_udtf_and_dynamic_pivot_along_axis_1( Returns: SnowflakeQueryCompiler which may be Series or DataFrame representing result of .apply(axis=1) """ + self._raise_not_implemented_error_for_timedelta() + # Process using general approach via UDTF + dynamic pivot to handle column expansion case. # Overwrite partition-size with kwargs arg @@ -7020,6 +7218,8 @@ def _apply_with_udtf_and_dynamic_pivot_along_axis_1( data_column_snowflake_quoted_identifiers=renamed_data_column_snowflake_quoted_identifiers, index_column_pandas_labels=new_internal_df.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=new_internal_df.index_column_snowflake_quoted_identifiers, + data_column_types=None, + index_column_types=None, ) return SnowflakeQueryCompiler(new_internal_frame) @@ -7060,6 +7260,7 @@ def _apply_udf_row_wise_and_reduce_to_series_along_axis_1( Returns: SnowflakeQueryCompiler representing a Series holding the result of apply(func, axis=1). """ + self._raise_not_implemented_error_for_timedelta() # extract index columns and types, which are passed as first columns to UDF. type_map = self._modin_frame.quoted_identifier_to_snowflake_type() @@ -7137,6 +7338,8 @@ def vectorized_udf(df: pandas.DataFrame) -> pandas.Series: # pragma: no cover data_column_snowflake_quoted_identifiers=[new_identifier], index_column_pandas_labels=self._modin_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=self._modin_frame.index_column_snowflake_quoted_identifiers, + data_column_types=None, + index_column_types=None, ) return SnowflakeQueryCompiler(new_frame) @@ -7175,6 +7378,7 @@ def apply( **kwargs : dict Keyword arguments to pass to `func`. """ + self._raise_not_implemented_error_for_timedelta() # axis=0 is not supported, raise error. if axis == 0: @@ -7268,6 +7472,8 @@ def applymap( *args : iterable **kwargs : dict """ + self._raise_not_implemented_error_for_timedelta() + # Currently, NULL values are always passed into the udtf even if strict=True, # which is a bug on the server side SNOW-880105. # The fix will not land soon, so we are going to raise not implemented error for now. @@ -7308,6 +7514,8 @@ def map( na_action: Optional[Literal["ignore"]] = None, ) -> "SnowflakeQueryCompiler": """This method will only be called from Series.""" + self._raise_not_implemented_error_for_timedelta() + # TODO SNOW-801847: support series.map when arg is a dict/series # Currently, NULL values are always passed into the udtf even if strict=True, # which is a bug on the server side SNOW-880105. 
@@ -7338,6 +7546,8 @@ def apply_on_series( **kwargs : dict Keyword arguments to pass to `func`. """ + self._raise_not_implemented_error_for_timedelta() + assert self.is_series_like() # TODO SNOW-856682: support other types (str, list, dict) of func @@ -7392,6 +7602,8 @@ def pivot( ------- SnowflakeQueryCompiler """ + self._raise_not_implemented_error_for_timedelta() + # Call pivot_table which is a more generalized version of pivot with `min` aggregation # Note we differ from pandas by not checking for duplicates and raising a ValueError as that would require an eager query return self.pivot_table( @@ -7420,6 +7632,8 @@ def pivot_table( observed: bool, sort: bool, ) -> "SnowflakeQueryCompiler": + self._raise_not_implemented_error_for_timedelta() + """ Create a spreadsheet-style pivot table from underlying data. @@ -7769,6 +7983,8 @@ def take_2d_positional( BaseQueryCompiler New masked QueryCompiler. """ + self._raise_not_implemented_error_for_timedelta() + # TODO: SNOW-884220 support multiindex # index can only be a query compiler or slice object assert isinstance(index, (SnowflakeQueryCompiler, slice)) @@ -7923,6 +8139,8 @@ def make_nunique(identifier: str, dropna: bool) -> SnowparkColumn: data_column_pandas_index_names=internal_frame.data_column_pandas_index_names, index_column_pandas_labels=[INDEX_LABEL], index_column_snowflake_quoted_identifiers=[new_index_identifier], + data_column_types=None, # no snowpark pandas type for nunique + index_column_types=None, # no snowpark pandas type for nunique ) return SnowflakeQueryCompiler(frame) @@ -8085,6 +8303,8 @@ def take_2d_labels( ------- SnowflakeQueryCompiler """ + self._raise_not_implemented_error_for_timedelta() + if self._modin_frame.is_multiindex(axis=0) and ( is_scalar(index) or isinstance(index, tuple) ): @@ -8191,6 +8411,8 @@ def transpose_single_row(self) -> "SnowflakeQueryCompiler": SnowflakeQueryCompiler Transposed new QueryCompiler object. """ + self._raise_not_implemented_error_for_timedelta() + frame = self._modin_frame # Handle case where the dataframe has empty columns. @@ -8347,6 +8569,7 @@ def transpose(self) -> "SnowflakeQueryCompiler": # STEP 1) Construct a temporary index column that contains the original index with position. # STEP 2) Perform an unpivot which flattens the original data columns into a single name and value rows # grouped by the temporary transpose index column. 
+ self._raise_not_implemented_error_for_timedelta() unpivot_result = prepare_and_unpivot_for_transpose( frame, self, is_single_row=False @@ -8855,6 +9078,10 @@ def case_when(self, caselist: List[tuple]) -> "SnowflakeQueryCompiler": data_column_snowflake_quoted_identifiers=joined_frame.data_column_snowflake_quoted_identifiers[ :1 ], + data_column_types=joined_frame.cached_index_column_snowpark_pandas_types[ + :1 + ], + index_column_types=joined_frame.cached_index_column_snowpark_pandas_types, ) return SnowflakeQueryCompiler(new_frame) @@ -9319,6 +9546,8 @@ def where( data_column_snowflake_quoted_identifiers=new_data_column_snowflake_quoted_identifiers, index_column_pandas_labels=frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=joined_frame.index_column_snowflake_quoted_identifiers, + data_column_types=frame.cached_data_column_snowpark_pandas_types, + index_column_types=joined_frame.cached_index_column_snowpark_pandas_types, ) return SnowflakeQueryCompiler(new_frame) @@ -9694,6 +9923,8 @@ def dropna( data_column_snowflake_quoted_identifiers=self._modin_frame.data_column_snowflake_quoted_identifiers, index_column_pandas_labels=self._modin_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=self._modin_frame.index_column_snowflake_quoted_identifiers, + data_column_types=self._modin_frame.cached_data_column_snowpark_pandas_types, + index_column_types=self._modin_frame.cached_index_column_snowpark_pandas_types, ) ) @@ -9728,6 +9959,8 @@ def set_index_names( data_column_pandas_index_names=frame.data_column_pandas_index_names, index_column_pandas_labels=names, index_column_snowflake_quoted_identifiers=frame.index_column_snowflake_quoted_identifiers, + data_column_types=frame.cached_data_column_snowpark_pandas_types, + index_column_types=frame.cached_index_column_snowpark_pandas_types, ) return SnowflakeQueryCompiler(frame) @@ -9986,6 +10219,8 @@ def drop( data_column_pandas_labels=data_column_labels, data_column_snowflake_quoted_identifiers=data_column_identifiers, data_column_pandas_index_names=frame.data_column_pandas_index_names, + data_column_types=None, + index_column_types=None, ) frame = frame.select_active_columns() @@ -10050,6 +10285,8 @@ def _drop_axis_0( data_column_pandas_labels=frame.data_column_pandas_labels, data_column_snowflake_quoted_identifiers=frame.data_column_snowflake_quoted_identifiers, data_column_pandas_index_names=frame.data_column_pandas_index_names, + data_column_types=frame.cached_data_column_snowpark_pandas_types, + index_column_types=frame.cached_index_column_snowpark_pandas_types, ) return SnowflakeQueryCompiler(frame) @@ -10724,6 +10961,8 @@ def rank( data_column_pandas_labels=new_frame.data_column_pandas_labels, data_column_snowflake_quoted_identifiers=new_frame.data_column_snowflake_quoted_identifiers, data_column_pandas_index_names=new_frame.data_column_pandas_index_names, + data_column_types=new_frame.cached_data_column_snowpark_pandas_types, + index_column_types=new_frame.cached_index_column_snowpark_pandas_types, ) return SnowflakeQueryCompiler(new_frame) @@ -11095,6 +11334,8 @@ def _value_counts_groupby( dropna : bool Don't include counts of NaN. 
""" + self._raise_not_implemented_error_for_timedelta() + # validate whether by is valid (e.g., contains duplicates or non-existing labels) self.validate_groupby(by=by, axis=0, level=None) @@ -11154,6 +11395,8 @@ def _value_counts_groupby( data_column_pandas_labels=[MODIN_UNNAMED_SERIES_LABEL], data_column_snowflake_quoted_identifiers=[count_identifier], data_column_pandas_index_names=query_compiler._modin_frame.data_column_pandas_index_names, + data_column_types=[None], + index_column_types=internal_frame.cached_index_column_snowpark_pandas_types, ) ) @@ -11214,6 +11457,9 @@ def build_repr_df( data_column_pandas_index_names=self._modin_frame.data_column_pandas_index_names, index_column_pandas_labels=self._modin_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=self._modin_frame.index_column_snowflake_quoted_identifiers, + data_column_types=self._modin_frame.cached_data_column_snowpark_pandas_types + + [None], + index_column_types=self._modin_frame.cached_index_column_snowpark_pandas_types, ) row_count_identifier = ( @@ -11398,6 +11644,8 @@ def quantiles_along_axis0( index_column_snowflake_quoted_identifiers=[ index_column_snowflake_quoted_identifier ], + data_column_types=original_frame.cached_data_column_snowpark_pandas_types, + index_column_types=[None], ) ) @@ -11481,6 +11729,8 @@ def _quantiles_single_col( column would allow us to create an accurate row position column, but would require a potentially expensive JOIN operator afterwards to apply the correct index labels. """ + self._raise_not_implemented_error_for_timedelta() + assert len(self._modin_frame.data_column_pandas_labels) == 1 if index is not None: @@ -11545,6 +11795,8 @@ def _quantiles_single_col( ], index_column_pandas_labels=[None], index_column_snowflake_quoted_identifiers=[index_identifier], + data_column_types=None, + index_column_types=None, ) # We cannot call astype() directly to convert an index column, so we replicate # the logic here so we don't have to mess with set_index. 
@@ -11887,6 +12139,8 @@ def count_freqs( data_column_snowflake_quoted_identifiers=top_freq_identifiers, index_column_pandas_labels=new_index_labels, index_column_snowflake_quoted_identifiers=new_index_identifiers, + data_column_types=[None, None], + index_column_types=[None] * len(new_index_labels), ) ).transpose() query_compilers_to_concat.extend([unique_qc, top_freq_qc]) @@ -12159,6 +12413,8 @@ def sample( data_column_snowflake_quoted_identifiers=frame.data_column_snowflake_quoted_identifiers, index_column_pandas_labels=frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=frame.index_column_snowflake_quoted_identifiers, + data_column_types=frame.cached_data_column_snowpark_pandas_types, + index_column_types=frame.cached_index_column_snowpark_pandas_types, ) ) if ignore_index: @@ -12223,6 +12479,8 @@ def rolling_count( *args: Any, **kwargs: Any, ) -> "SnowflakeQueryCompiler": + self._raise_not_implemented_error_for_timedelta() + return self._window_agg( window_func=WindowFunction.ROLLING, agg_func="count", @@ -12240,6 +12498,8 @@ def rolling_sum( *args: Any, **kwargs: Any, ) -> "SnowflakeQueryCompiler": + self._raise_not_implemented_error_for_timedelta() + WarningMessage.warning_if_engine_args_is_set( "rolling_sum", engine, engine_kwargs ) @@ -12260,6 +12520,8 @@ def rolling_mean( *args: Any, **kwargs: Any, ) -> "SnowflakeQueryCompiler": + self._raise_not_implemented_error_for_timedelta() + WarningMessage.warning_if_engine_args_is_set( "rolling_mean", engine, engine_kwargs ) @@ -12292,6 +12554,8 @@ def rolling_var( *args: Any, **kwargs: Any, ) -> "SnowflakeQueryCompiler": + self._raise_not_implemented_error_for_timedelta() + WarningMessage.warning_if_engine_args_is_set( "rolling_var", engine, engine_kwargs ) @@ -12313,6 +12577,8 @@ def rolling_std( *args: Any, **kwargs: Any, ) -> "SnowflakeQueryCompiler": + self._raise_not_implemented_error_for_timedelta() + WarningMessage.warning_if_engine_args_is_set( "rolling_var", engine, engine_kwargs ) @@ -12333,6 +12599,8 @@ def rolling_min( *args: Any, **kwargs: Any, ) -> "SnowflakeQueryCompiler": + self._raise_not_implemented_error_for_timedelta() + WarningMessage.warning_if_engine_args_is_set( "rolling_min", engine, engine_kwargs ) @@ -12353,6 +12621,8 @@ def rolling_max( *args: Any, **kwargs: Any, ) -> "SnowflakeQueryCompiler": + self._raise_not_implemented_error_for_timedelta() + WarningMessage.warning_if_engine_args_is_set( "rolling_max", engine, engine_kwargs ) @@ -12448,6 +12718,8 @@ def rolling_sem( *args: Any, **kwargs: Any, ) -> "SnowflakeQueryCompiler": + self._raise_not_implemented_error_for_timedelta() + return self._window_agg( window_func=WindowFunction.ROLLING, agg_func="sem", @@ -12592,6 +12864,8 @@ def expanding_count( expanding_kwargs: dict, numeric_only: bool = False, ) -> "SnowflakeQueryCompiler": + self._raise_not_implemented_error_for_timedelta() + return self._window_agg( window_func=WindowFunction.EXPANDING, agg_func="count", @@ -12607,6 +12881,8 @@ def expanding_sum( engine: Optional[Literal["cython", "numba"]] = None, engine_kwargs: Optional[dict[str, bool]] = None, ) -> "SnowflakeQueryCompiler": + self._raise_not_implemented_error_for_timedelta() + WarningMessage.warning_if_engine_args_is_set( "expanding_sum", engine, engine_kwargs ) @@ -12625,6 +12901,8 @@ def expanding_mean( engine: Optional[Literal["cython", "numba"]] = None, engine_kwargs: Optional[dict[str, bool]] = None, ) -> "SnowflakeQueryCompiler": + self._raise_not_implemented_error_for_timedelta() + 
WarningMessage.warning_if_engine_args_is_set( "expanding_mean", engine, engine_kwargs ) @@ -12654,6 +12932,8 @@ def expanding_var( engine: Optional[Literal["cython", "numba"]] = None, engine_kwargs: Optional[dict[str, bool]] = None, ) -> "SnowflakeQueryCompiler": + self._raise_not_implemented_error_for_timedelta() + WarningMessage.warning_if_engine_args_is_set( "rolling_var", engine, engine_kwargs ) @@ -12673,6 +12953,8 @@ def expanding_std( engine: Optional[Literal["cython", "numba"]] = None, engine_kwargs: Optional[dict[str, bool]] = None, ) -> "SnowflakeQueryCompiler": + self._raise_not_implemented_error_for_timedelta() + WarningMessage.warning_if_engine_args_is_set( "rolling_std", engine, engine_kwargs ) @@ -12691,6 +12973,8 @@ def expanding_min( engine: Optional[Literal["cython", "numba"]] = None, engine_kwargs: Optional[dict[str, bool]] = None, ) -> "SnowflakeQueryCompiler": + self._raise_not_implemented_error_for_timedelta() + WarningMessage.warning_if_engine_args_is_set( "expanding_min", engine, engine_kwargs ) @@ -12709,6 +12993,8 @@ def expanding_max( engine: Optional[Literal["cython", "numba"]] = None, engine_kwargs: Optional[dict[str, bool]] = None, ) -> "SnowflakeQueryCompiler": + self._raise_not_implemented_error_for_timedelta() + WarningMessage.warning_if_engine_args_is_set( "expanding_max", engine, engine_kwargs ) @@ -12797,6 +13083,8 @@ def expanding_sem( ddof: int = 1, numeric_only: bool = False, ) -> "SnowflakeQueryCompiler": + self._raise_not_implemented_error_for_timedelta() + return self._window_agg( window_func=WindowFunction.EXPANDING, agg_func="sem", @@ -12986,6 +13274,7 @@ def replace( ) return SnowflakeQueryCompiler(result.frame) + @snowpark_pandas_type_immutable_check def add_substring( self, substring: str, @@ -13065,6 +13354,8 @@ def add_substring( data_column_snowflake_quoted_identifiers=data_column_snowflake_quoted_identifiers, index_column_pandas_labels=frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=frame.index_column_snowflake_quoted_identifiers, + data_column_types=frame.cached_data_column_snowpark_pandas_types, + index_column_types=frame.cached_index_column_snowpark_pandas_types, ) if axis == 0: @@ -13218,6 +13509,8 @@ def duplicated( index_column_snowflake_quoted_identifiers=[ row_position_post_dedup.row_position_snowflake_quoted_identifier ], + data_column_types=[None], + index_column_types=[None], ) joined_ordered_dataframe = join_utils.join( @@ -13250,6 +13543,8 @@ def duplicated( data_column_pandas_index_names=frame.data_column_pandas_index_names, index_column_pandas_labels=frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=frame.index_column_snowflake_quoted_identifiers, + data_column_types=[None], + index_column_types=frame.cached_index_column_snowpark_pandas_types, ) return SnowflakeQueryCompiler(new_frame) @@ -13305,6 +13600,7 @@ def _binary_op_between_dataframe_and_series_along_axis_0( Returns: SnowflakeQueryCompiler representing result of binary op operation. 
""" + self._raise_not_implemented_error_for_timedelta() assert ( other.is_series_like() @@ -13412,6 +13708,7 @@ def create_lazy_type_functions( return SnowflakeQueryCompiler(new_frame) + @snowpark_pandas_type_immutable_check def round( self, decimals: Union[int, Mapping, "pd.Series"] = 0, **kwargs: Any ) -> "SnowflakeQueryCompiler": @@ -13482,6 +13779,8 @@ def idxmax( Returns: SnowflakeQueryCompiler """ + self._raise_not_implemented_error_for_timedelta() + return self._idxmax_idxmin( func="idxmax", axis=axis, skipna=skipna, numeric_only=numeric_only ) @@ -13506,6 +13805,8 @@ def idxmin( Returns: SnowflakeQueryCompiler """ + self._raise_not_implemented_error_for_timedelta() + return self._idxmax_idxmin( func="idxmin", axis=axis, skipna=skipna, numeric_only=numeric_only ) @@ -13647,6 +13948,8 @@ def infer_sorted_column_labels( data_column_snowflake_quoted_identifiers=updated_data_identifiers, index_column_pandas_labels=new_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=new_frame.index_column_snowflake_quoted_identifiers, + data_column_types=None, + index_column_types=None, ) return SnowflakeQueryCompiler(result_frame) @@ -13816,6 +14119,8 @@ def infer_sorted_column_labels( + new_identifiers, index_column_pandas_labels=index_column_pandas_labels, index_column_snowflake_quoted_identifiers=frame.index_column_snowflake_quoted_identifiers, + data_column_types=None, + index_column_types=None, ) # Replace all columns with NULL literals. new_frame = new_frame.update_snowflake_quoted_identifiers_with_expressions( @@ -13860,6 +14165,8 @@ def infer_sorted_column_labels( data_column_snowflake_quoted_identifiers=expanded_data_column_snowflake_quoted_identifiers, index_column_pandas_labels=index_column_pandas_labels, index_column_snowflake_quoted_identifiers=frame.index_column_snowflake_quoted_identifiers, + data_column_types=None, + index_column_types=None, ) # For columns that exist in both self and other, update the corresponding identifier with the result @@ -14970,6 +15277,7 @@ def output_col(col_name: ColumnOrName) -> SnowparkColumn: ) return SnowflakeQueryCompiler(new_internal_frame) + @snowpark_pandas_type_immutable_check def str_strip(self, to_strip: Union[str, None] = None) -> "SnowflakeQueryCompiler": """ Remove leading and trailing characters. @@ -14989,6 +15297,7 @@ def str_strip(self, to_strip: Union[str, None] = None) -> "SnowflakeQueryCompile sp_func=trim, pd_func_name="strip", to_strip=to_strip ) + @snowpark_pandas_type_immutable_check def str_lstrip(self, to_strip: Union[str, None] = None) -> "SnowflakeQueryCompiler": """ Remove leading characters. @@ -15008,6 +15317,7 @@ def str_lstrip(self, to_strip: Union[str, None] = None) -> "SnowflakeQueryCompil sp_func=ltrim, pd_func_name="lstrip", to_strip=to_strip ) + @snowpark_pandas_type_immutable_check def str_rstrip(self, to_strip: Union[str, None] = None) -> "SnowflakeQueryCompiler": """ Remove trailing characters. @@ -15030,6 +15340,7 @@ def str_rstrip(self, to_strip: Union[str, None] = None) -> "SnowflakeQueryCompil def str_swapcase(self) -> None: ErrorMessage.method_not_implemented_error("swapcase", "Series.str") + @snowpark_pandas_type_immutable_check def str_translate(self, table: dict) -> "SnowflakeQueryCompiler": """ Map all characters in the string through the given mapping table. 
@@ -15206,6 +15517,8 @@ def qcut( data_column_snowflake_quoted_identifiers=[new_data_identifier], index_column_pandas_labels=self._modin_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=self._modin_frame.index_column_snowflake_quoted_identifiers, + data_column_types=self._modin_frame.cached_data_column_snowpark_pandas_types, + index_column_types=self._modin_frame.cached_index_column_snowpark_pandas_types, ) return SnowflakeQueryCompiler(new_frame) @@ -15240,6 +15553,8 @@ def _groupby_head_tail( Returns: A SnowflakeQueryCompiler object representing a DataFrame. """ + self._raise_not_implemented_error_for_timedelta() + original_frame = self._modin_frame ordered_dataframe = original_frame.ordered_dataframe @@ -15386,6 +15701,8 @@ def _groupby_head_tail( data_column_snowflake_quoted_identifiers=data_column_snowflake_quoted_identifiers, index_column_pandas_labels=original_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=index_column_snowflake_quoted_identifiers, + data_column_types=None, + index_column_types=None, ) return SnowflakeQueryCompiler(new_modin_frame) @@ -15664,6 +15981,10 @@ def dt_ceil( ], index_column_pandas_labels=internal_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=internal_frame.index_column_snowflake_quoted_identifiers, + data_column_types=internal_frame.cached_data_column_snowpark_pandas_types[ + -1: + ], + index_column_types=internal_frame.cached_index_column_snowpark_pandas_types, ) ) @@ -15859,6 +16180,10 @@ def dt_floor( ], index_column_pandas_labels=internal_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=internal_frame.index_column_snowflake_quoted_identifiers, + data_column_types=internal_frame.cached_data_column_snowpark_pandas_types[ + -1: + ], + index_column_types=internal_frame.cached_index_column_snowpark_pandas_types, ) ) @@ -15915,6 +16240,10 @@ def dt_month_name(self, locale: Optional[str] = None) -> "SnowflakeQueryCompiler ], index_column_pandas_labels=internal_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=internal_frame.index_column_snowflake_quoted_identifiers, + data_column_types=internal_frame.cached_data_column_snowpark_pandas_types[ + -1: + ], + index_column_types=internal_frame.cached_index_column_snowpark_pandas_types, ) ) @@ -15959,6 +16288,10 @@ def dt_day_name(self, locale: Optional[str] = None) -> "SnowflakeQueryCompiler": ], index_column_pandas_labels=internal_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=internal_frame.index_column_snowflake_quoted_identifiers, + data_column_types=internal_frame.cached_data_column_snowpark_pandas_types[ + -1: + ], + index_column_types=internal_frame.cached_index_column_snowpark_pandas_types, ) ) @@ -16067,6 +16400,8 @@ def topn( data_column_pandas_index_names=internal_frame.data_column_pandas_index_names, index_column_pandas_labels=internal_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=internal_frame.index_column_snowflake_quoted_identifiers, + data_column_types=internal_frame.cached_data_column_snowpark_pandas_types, + index_column_types=internal_frame.cached_index_column_snowpark_pandas_types, ) ) @@ -16296,6 +16631,8 @@ def equals( index_column_pandas_labels=self_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=other_frame.index_column_snowflake_quoted_identifiers, data_column_pandas_index_names=self_frame.data_column_pandas_index_names, + 
data_column_types=other_frame.cached_data_column_snowpark_pandas_types, + index_column_types=other_frame.cached_index_column_snowpark_pandas_types, ) # Align (join) both dataframes on index. @@ -16320,6 +16657,10 @@ def equals( updated_result.old_id_to_new_id_mappings[p.identifier] for p in left_right_pairs ] + updated_data_identifiers_types = [ + updated_result.frame.snowflake_quoted_identifier_to_snowpark_pandas_type[id] + for id in updated_data_identifiers + ] new_frame = updated_result.frame result_frame = InternalFrame.create( ordered_dataframe=new_frame.ordered_dataframe, @@ -16328,6 +16669,8 @@ def equals( data_column_snowflake_quoted_identifiers=updated_data_identifiers, index_column_pandas_labels=new_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=new_frame.index_column_snowflake_quoted_identifiers, + data_column_types=updated_data_identifiers_types, + index_column_types=new_frame.cached_index_column_snowpark_pandas_types, ) return SnowflakeQueryCompiler(result_frame) @@ -16654,6 +16997,8 @@ def corr( data_column_snowflake_quoted_identifiers=frame.data_column_snowflake_quoted_identifiers, index_column_pandas_labels=[None], index_column_snowflake_quoted_identifiers=[index_quoted_identifier], + data_column_types=None, + index_column_types=[None], ) query_compilers.append(SnowflakeQueryCompiler(new_frame)) diff --git a/src/snowflake/snowpark/modin/plugin/utils/error_message.py b/src/snowflake/snowpark/modin/plugin/utils/error_message.py index d2547e8a4f7..997af701f2b 100644 --- a/src/snowflake/snowpark/modin/plugin/utils/error_message.py +++ b/src/snowflake/snowpark/modin/plugin/utils/error_message.py @@ -158,6 +158,12 @@ def not_implemented(cls, message: str) -> NoReturn: # pragma: no cover logger.debug(f"NotImplementedError: {message}") raise NotImplementedError(message) + @classmethod + def not_implemented_for_timedelta(cls, method: str) -> NoReturn: + ErrorMessage.not_implemented( + f"SnowflakeQueryCompiler::{method} is not yet implemented for Timedelta Type" + ) + @staticmethod def method_not_implemented_error( name: str, class_: str diff --git a/tests/integ/modin/frame/test_reset_index.py b/tests/integ/modin/frame/test_reset_index.py index 62c1b17ff96..6d36d4cfc4a 100644 --- a/tests/integ/modin/frame/test_reset_index.py +++ b/tests/integ/modin/frame/test_reset_index.py @@ -20,6 +20,11 @@ def native_df_simple(): { "a": ["one", "two", "three"], "b": ["abc", "pqr", "xyz"], + "dt": [ + native_pd.Timedelta("1 days"), + native_pd.Timedelta("2 days"), + native_pd.Timedelta("3 days"), + ], }, index=native_pd.Index(["a", "b", "c"], name="c"), ) @@ -74,13 +79,13 @@ def test_reset_index_drop_false(native_df_simple): eval_snowpark_pandas_result(snow_df, native_df_simple, lambda df: df.reset_index()) snow_df = snow_df.reset_index() - assert ["c", "a", "b"] == list(snow_df.columns) + assert ["c", "a", "b", "dt"] == list(snow_df.columns) snow_df = snow_df.reset_index() - assert ["index", "c", "a", "b"] == list(snow_df.columns) + assert ["index", "c", "a", "b", "dt"] == list(snow_df.columns) snow_df = snow_df.reset_index() - assert ["level_0", "index", "c", "a", "b"] == list(snow_df.columns) + assert ["level_0", "index", "c", "a", "b", "dt"] == list(snow_df.columns) @sql_count_checker(query_count=1) @@ -168,7 +173,7 @@ def test_reset_index_allow_duplicates(native_df_simple): # Allow duplicates when provided name conflicts with existing data label. 
snow_df = pd.DataFrame(native_df_simple) snow_df = snow_df.reset_index(drop=False, allow_duplicates=True, names=["a"]) - assert ["a", "a", "b"] == list(snow_df.columns) + assert ["a", "a", "b", "dt"] == list(snow_df.columns) # Verify even if allow_duplicates is True, "index" is not duplicated. snow_df = pd.DataFrame({"index": ["one", "two", "three"]}) diff --git a/tests/integ/modin/series/test_cache_result.py b/tests/integ/modin/series/test_cache_result.py index 5a6b74ad95c..bba55c01ac9 100644 --- a/tests/integ/modin/series/test_cache_result.py +++ b/tests/integ/modin/series/test_cache_result.py @@ -12,7 +12,7 @@ from pandas.testing import assert_series_equal import snowflake.snowpark.modin.plugin # noqa: F401 -from tests.integ.modin.sql_counter import SqlCounter +from tests.integ.modin.sql_counter import SqlCounter, sql_count_checker from tests.integ.modin.utils import ( assert_snowpark_pandas_equals_to_pandas_without_dtypecheck, create_test_series, @@ -149,3 +149,15 @@ def test_cache_result_post_apply(self, inplace, simple_test_data): cached_snow_series, native_series, ) + + +@sql_count_checker(query_count=1) +def test_cacheresult_timedelta(): + native_s = native_pd.Series( + [ + native_pd.Timedelta("1 days"), + native_pd.Timedelta("2 days"), + native_pd.Timedelta("3 days"), + ] + ) + assert "timedelta64[ns]" == pd.Series(native_s).cache_result().dtype diff --git a/tests/integ/modin/series/test_copy.py b/tests/integ/modin/series/test_copy.py index 3b8efe25207..b2dcb9f927d 100644 --- a/tests/integ/modin/series/test_copy.py +++ b/tests/integ/modin/series/test_copy.py @@ -2,11 +2,15 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # import modin.pandas as pd +import pandas as native_pd import pytest import snowflake.snowpark.modin.plugin # noqa: F401 from tests.integ.modin.sql_counter import SqlCounter, sql_count_checker -from tests.integ.modin.utils import assert_snowpark_pandas_equal_to_pandas +from tests.integ.modin.utils import ( + assert_snowpark_pandas_equal_to_pandas, + eval_snowpark_pandas_result, +) @pytest.fixture(scope="function") @@ -77,3 +81,15 @@ def test_copy_inplace_operations_on_shallow_copy(snow_series, operation): # Verify that 'snow_series' is also changed. 
assert_snowpark_pandas_equal_to_pandas(snow_series, copy.to_pandas()) + + +@sql_count_checker(query_count=1) +def test_copy_timedelta(): + native_s = native_pd.Series( + [ + native_pd.Timedelta("1 days"), + native_pd.Timedelta("2 days"), + native_pd.Timedelta("3 days"), + ] + ) + eval_snowpark_pandas_result(pd.Series(native_s), native_s, lambda s: s.copy()) diff --git a/tests/integ/modin/series/test_shift.py b/tests/integ/modin/series/test_shift.py index aee56bce775..7f27c4d313b 100644 --- a/tests/integ/modin/series/test_shift.py +++ b/tests/integ/modin/series/test_shift.py @@ -17,6 +17,13 @@ native_pd.Series([1, 2, 3, 4]), native_pd.Series(["a", None, 1, 2, 4.5]), native_pd.Series([2.0, None, 3.6, -10], index=[1, 2, 3, 4]), + native_pd.Series( + [ + native_pd.Timedelta("1 days"), + native_pd.Timedelta("2 days"), + native_pd.Timedelta("3 days"), + ], + ), ] @@ -36,7 +43,16 @@ def test_series_with_values_shift(series, periods, fill_value): eval_snowpark_pandas_result( snow_series, native_series, - lambda s: s.shift(periods=periods, fill_value=fill_value), + lambda s: s.shift( + periods=periods, + fill_value=pd.Timedelta(fill_value) + if isinstance( + s, native_pd.Series + ) # pandas does not support fill int to timedelta + and s.dtype == "timedelta64[ns]" + and fill_value is not no_default + else fill_value, + ), ) diff --git a/tests/integ/modin/series/test_sort_index.py b/tests/integ/modin/series/test_sort_index.py index 9efc2fb4d84..2a63da083f4 100644 --- a/tests/integ/modin/series/test_sort_index.py +++ b/tests/integ/modin/series/test_sort_index.py @@ -15,9 +15,21 @@ @pytest.mark.parametrize("na_position", ["first", "last"]) @pytest.mark.parametrize("ignore_index", [True, False]) @pytest.mark.parametrize("inplace", [True, False]) +@pytest.mark.parametrize( + "data", + [ + ["a", "b", np.nan, "d"], + [ + native_pd.Timedelta("1 days"), + native_pd.Timedelta("2 days"), + native_pd.Timedelta("3 days"), + native_pd.Timedelta(None), + ], + ], +) @sql_count_checker(query_count=1) -def test_sort_index_series(ascending, na_position, ignore_index, inplace): - native_series = native_pd.Series(["a", "b", np.nan, "d"], index=[3, 2, 1, np.nan]) +def test_sort_index_series(ascending, na_position, ignore_index, inplace, data): + native_series = native_pd.Series(data, index=[3, 2, 1, np.nan]) snow_series = pd.Series(native_series) eval_snowpark_pandas_result( snow_series, diff --git a/tests/integ/modin/types/test_timedelta.py b/tests/integ/modin/types/test_timedelta.py index 37447db57ca..f0d8440009f 100644 --- a/tests/integ/modin/types/test_timedelta.py +++ b/tests/integ/modin/types/test_timedelta.py @@ -94,3 +94,23 @@ def test_timedelta_precision_insufficient_with_nulls_SNOW_1628925(): eval_snowpark_pandas_result( pd, native_pd, lambda lib: lib.Series([None, timedelta]) ) + + +@sql_count_checker(query_count=0) +def test_timedelta_not_supported(): + df = pd.DataFrame( + { + "a": ["one", "two", "three"], + "b": ["abc", "pqr", "xyz"], + "dt": [ + pd.Timedelta("1 days"), + pd.Timedelta("2 days"), + pd.Timedelta("3 days"), + ], + } + ) + with pytest.raises( + NotImplementedError, + match="validate_groupby is not yet implemented for Timedelta Type", + ): + df.groupby("a").count() diff --git a/tests/unit/modin/conftest.py b/tests/unit/modin/conftest.py index 256c0e7aa3f..863078f14e2 100644 --- a/tests/unit/modin/conftest.py +++ b/tests/unit/modin/conftest.py @@ -20,6 +20,9 @@ def mock_single_col_query_compiler() -> SnowflakeQueryCompiler: mock_internal_frame = mock.create_autospec(InternalFrame) 
mock_internal_frame.data_columns_index = native_pd.Index(["A"], name="B") mock_internal_frame.data_column_snowflake_quoted_identifiers = ['"A"'] + mock_internal_frame.snowflake_quoted_identifier_to_snowpark_pandas_type = { + '"A"': None + } fake_query_compiler = SnowflakeQueryCompiler(mock_internal_frame) return fake_query_compiler diff --git a/tests/unit/modin/test_internal_frame.py b/tests/unit/modin/test_internal_frame.py index 15cc6105e3f..cd5e57eb5a1 100644 --- a/tests/unit/modin/test_internal_frame.py +++ b/tests/unit/modin/test_internal_frame.py @@ -73,6 +73,8 @@ def test_dataframes(mock_dataframe) -> TestDataFrames: data_column_snowflake_quoted_identifiers=['"a"', '"b"', '"C"', '"d"'], index_column_pandas_labels=[None], index_column_snowflake_quoted_identifiers=['"INDEX"'], + data_column_types=None, + index_column_types=None, ) return TestDataFrames(ordered_dataframe, internal_frame) @@ -110,6 +112,8 @@ def test_dataframes_with_multiindex_on_column(mock_dataframe) -> TestDataFrames: data_column_snowflake_quoted_identifiers=["\"('a', 'C')\"", "\"('b', 'D')\""], index_column_pandas_labels=[None], index_column_snowflake_quoted_identifiers=['"INDEX"'], + data_column_types=None, + index_column_types=None, ) return TestDataFrames(ordered_dataframe, internal_frame) @@ -125,6 +129,8 @@ def test_snowflake_quoted_identifier_without_quote_negative(test_dataframes) -> data_column_snowflake_quoted_identifiers=["a", "b", "c"], index_column_pandas_labels=[None], index_column_snowflake_quoted_identifiers=['"INDEX"'], + data_column_types=None, + index_column_types=None, ) assert "Found not-quoted identifier for 'dataframe column':'a'" in str(exc.value) @@ -143,6 +149,8 @@ def test_column_labels_and_quoted_identifiers_have_same_length_negative( data_column_snowflake_quoted_identifiers=['"a"'], index_column_pandas_labels=[None], index_column_snowflake_quoted_identifiers=['"INDEX"'], + data_column_types=None, + index_column_types=None, ) # check index columns @@ -154,6 +162,8 @@ def test_column_labels_and_quoted_identifiers_have_same_length_negative( data_column_snowflake_quoted_identifiers=['"a"', '"b"', '"C"'], index_column_pandas_labels=[], index_column_snowflake_quoted_identifiers=['"INDEX"'], + data_column_types=None, + index_column_types=None, ) @@ -167,6 +177,8 @@ def test_internal_frame_missing_data_column_negative(test_dataframes): data_column_snowflake_quoted_identifiers=['"a"', '"b"', '"D"'], index_column_pandas_labels=[None], index_column_snowflake_quoted_identifiers=['"INDEX"'], + data_column_types=None, + index_column_types=None, ) assert 'dataframe column="D" not found in snowpark dataframe schema' in str( @@ -184,6 +196,8 @@ def test_internal_frame_missing_index_column_negative(test_dataframes): data_column_snowflake_quoted_identifiers=['"a"', '"b"', '"C"'], index_column_pandas_labels=[None], index_column_snowflake_quoted_identifiers=['"E"'], + data_column_types=None, + index_column_types=None, ) assert 'dataframe column="E" not found in snowpark dataframe schema' in str( @@ -230,6 +244,8 @@ def test_pandas_label_as_empty_and_none(test_dataframes) -> None: data_column_snowflake_quoted_identifiers=['"a"', '"b"', '"C"'], index_column_pandas_labels=[None], index_column_snowflake_quoted_identifiers=['"INDEX"'], + data_column_types=None, + index_column_types=None, ) assert internal_frame.data_column_pandas_labels == ["", "b", None] @@ -286,6 +302,8 @@ def test_internal_frame_ordering_columns(test_dataframes) -> None: data_column_snowflake_quoted_identifiers=['"a"', '"b"', '"C"'], 
index_column_pandas_labels=[None], index_column_snowflake_quoted_identifiers=['"INDEX"'], + data_column_types=None, + index_column_types=None, ) assert internal_frame.ordering_column_snowflake_quoted_identifiers == [ @@ -306,6 +324,8 @@ def test_internal_frame_ordering_columns(test_dataframes) -> None: data_column_snowflake_quoted_identifiers=['"a"'], index_column_pandas_labels=[None], index_column_snowflake_quoted_identifiers=['"INDEX"'], + data_column_types=None, + index_column_types=None, ) assert internal_frame.ordering_column_snowflake_quoted_identifiers == [ @@ -322,6 +342,8 @@ def test_internal_frame_ordering_columns(test_dataframes) -> None: data_column_snowflake_quoted_identifiers=['"a"', '"C"'], index_column_pandas_labels=[None], index_column_snowflake_quoted_identifiers=['"INDEX"'], + data_column_types=None, + index_column_types=None, ) assert internal_frame.ordering_column_snowflake_quoted_identifiers == [ @@ -344,6 +366,8 @@ def test_data_column_pandas_index_names(pandas_label, test_dataframes) -> None: data_column_snowflake_quoted_identifiers=['"a"', '"C"'], index_column_pandas_labels=[None], index_column_snowflake_quoted_identifiers=['"INDEX"'], + data_column_types=None, + index_column_types=None, ) assert internal_frame.data_column_pandas_index_names == [pandas_label] @@ -417,6 +441,8 @@ def test_data_column_pandas_multiindex_negative( ], index_column_pandas_labels=[None], index_column_snowflake_quoted_identifiers=['"INDEX"'], + data_column_types=None, + index_column_types=None, ) @@ -443,6 +469,8 @@ def test_get_snowflake_quoted_identifiers_by_pandas_labels_empty_not_include_ind data_column_snowflake_quoted_identifiers=['"a"', '"b"', '"C"'], index_column_pandas_labels=["index"], index_column_snowflake_quoted_identifiers=['"INDEX"'], + data_column_types=None, + index_column_types=None, ) assert internal_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels( @@ -625,6 +653,8 @@ def test_num_levels(mock_dataframe, level0, level1): data_column_snowflake_quoted_identifiers=['"x"', '"y"'], index_column_pandas_labels=[None] * level0, index_column_snowflake_quoted_identifiers=['"a"', '"b"'][:level0], + data_column_types=None, + index_column_types=None, ) assert frame.num_index_levels(axis=0) == level0 assert frame.num_index_levels(axis=1) == level1 @@ -704,6 +734,8 @@ def test_validation_duplicated_data_columns_for_labels( ], index_column_pandas_labels=["F"], index_column_snowflake_quoted_identifiers=['"F_INDEX"'], + data_column_types=None, + index_column_types=None, ) if expected_message is not None: diff --git a/tests/unit/modin/test_snowflake_query_compiler.py b/tests/unit/modin/test_snowflake_query_compiler.py index 90a2a5f3f1a..48775a47914 100644 --- a/tests/unit/modin/test_snowflake_query_compiler.py +++ b/tests/unit/modin/test_snowflake_query_compiler.py @@ -41,6 +41,8 @@ def test_query_compiler(mock_dataframe) -> SnowflakeQueryCompiler: data_column_snowflake_quoted_identifiers=['"a"', '"B"'], index_column_pandas_labels=["INDEX", "C"], index_column_snowflake_quoted_identifiers=['"INDEX"', '"C"'], + data_column_types=None, + index_column_types=None, ) return SnowflakeQueryCompiler(internal_frame)
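
Illustrative usage (not part of the patch): a minimal sketch of the behavior the new Timedelta guards introduce, assuming an active Snowpark pandas session. Everything below is drawn from the patch's own tests and changelog entries; operations listed as supported (`copy`, `cache_result`, `shift`, `sort_index`) pass through, while methods that call `_raise_not_implemented_error_for_timedelta` raise `NotImplementedError`, mirroring `tests/integ/modin/types/test_timedelta.py::test_timedelta_not_supported`.

    import modin.pandas as pd
    import pandas as native_pd
    import snowflake.snowpark.modin.plugin  # noqa: F401

    df = pd.DataFrame(
        {
            "key": ["a", "a", "b"],
            "dt": [
                native_pd.Timedelta("1 days"),
                native_pd.Timedelta("2 days"),
                native_pd.Timedelta("3 days"),
            ],
        }
    )

    # Supported: shift (like copy and sort_index) preserves the timedelta64[ns] dtype.
    shifted = df["dt"].shift(1)

    # Not yet supported: groupby aggregations raise NotImplementedError with a message like
    # "SnowflakeQueryCompiler::validate_groupby is not yet implemented for Timedelta Type".
    try:
        df.groupby("key").count()
    except NotImplementedError as e:
        print(e)

The same pattern applies to the other guarded query-compiler methods in this patch (pivot, merge, apply, rolling/expanding windows, etc.): they fail fast with a descriptive message instead of producing incorrect results on Timedelta columns.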