diff --git a/CHANGELOG.md b/CHANGELOG.md index 8e8de24f451..81aadb14d47 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -81,6 +81,7 @@ - support for `GroupBy` filtrations `first` and `last`. - support for `TimedeltaIndex` attributes: `days`, `seconds`, `microseconds` and `nanoseconds`. - support for `diff` with timestamp columns on `axis=0` and `axis=1` + - support for `TimedeltaIndex` methods: `ceil`, `floor` and `round`. - Added support for index's arithmetic and comparison operators. - Added support for `Series.dt.round`. - Added documentation pages for `DatetimeIndex`. diff --git a/docs/source/modin/supported/timedelta_index_supported.rst b/docs/source/modin/supported/timedelta_index_supported.rst index cd5e64b8c98..e11807ae14a 100644 --- a/docs/source/modin/supported/timedelta_index_supported.rst +++ b/docs/source/modin/supported/timedelta_index_supported.rst @@ -38,11 +38,11 @@ Methods +-----------------------------+---------------------------------+----------------------------------+-------------------------------------------+ | ``to_pytimedelta`` | N | | | +-----------------------------+---------------------------------+----------------------------------+-------------------------------------------+ -| ``round`` | N | | | +| ``round`` | Y | | | +-----------------------------+---------------------------------+----------------------------------+-------------------------------------------+ -| ``floor`` | N | | | +| ``floor`` | Y | | | +-----------------------------+---------------------------------+----------------------------------+-------------------------------------------+ -| ``ceil`` | N | | | +| ``ceil`` | Y | | | +-----------------------------+---------------------------------+----------------------------------+-------------------------------------------+ | ``mean`` | N | | | +-----------------------------+---------------------------------+----------------------------------+-------------------------------------------+ diff --git a/src/snowflake/snowpark/modin/plugin/_internal/frame.py b/src/snowflake/snowpark/modin/plugin/_internal/frame.py index c4ed377c05c..db8f85154d6 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/frame.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/frame.py @@ -1299,6 +1299,7 @@ def apply_snowpark_function_to_columns( self, snowpark_func: Callable[[Any], SnowparkColumn], include_index: bool = False, + return_type: Optional[SnowparkPandasType] = None, ) -> "InternalFrame": """ Apply snowpark function callable to all data columns of an InternalFrame. If @@ -1307,6 +1308,7 @@ def apply_snowpark_function_to_columns( Arguments: snowpark_func: Snowpark function to apply to columns of underlying snowpark df. + return_type: The optional SnowparkPandasType for the new column. include_index: Whether to apply the function to index columns as well. 
Returns: @@ -1317,7 +1319,8 @@ def apply_snowpark_function_to_columns( snowflake_ids.extend(self.index_column_snowflake_quoted_identifiers) return self.update_snowflake_quoted_identifiers_with_expressions( - {col_id: snowpark_func(col(col_id)) for col_id in snowflake_ids} + {col_id: snowpark_func(col(col_id)) for col_id in snowflake_ids}, + [return_type] * len(snowflake_ids) if return_type else None, ).frame def select_active_columns(self) -> "InternalFrame": diff --git a/src/snowflake/snowpark/modin/plugin/_internal/timestamp_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/timestamp_utils.py index c4873724789..0242177d1f0 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/timestamp_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/timestamp_utils.py @@ -11,7 +11,8 @@ import numpy as np import pandas as native_pd from pandas._libs import lib -from pandas._typing import DateTimeErrorChoices +from pandas._libs.tslibs import to_offset +from pandas._typing import DateTimeErrorChoices, Frequency from pandas.api.types import is_datetime64_any_dtype, is_float_dtype, is_integer_dtype from snowflake.snowpark import Column @@ -168,11 +169,24 @@ def col_to_s(col: Column, unit: Literal["D", "s", "ms", "us", "ns"]) -> Column: return col / 10**9 +def timedelta_freq_to_nanos(freq: Frequency) -> int: + """ + Convert a pandas frequency string to nanoseconds. + + Args: + freq: Timedelta frequency string or offset. + + Returns: + int: nanoseconds + """ + return to_offset(freq).nanos + + def col_to_timedelta(col: Column, unit: str) -> Column: """ Converts ``col`` (stored in the specified units) to timedelta nanoseconds. """ - td_unit = VALID_PANDAS_TIMEDELTA_ABBREVS.get(unit) + td_unit = VALID_PANDAS_TIMEDELTA_ABBREVS.get(unit.lower()) if not td_unit: # Same error as native pandas. raise ValueError(f"invalid unit abbreviation: {unit}") diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index ac220f91ecf..cff08a0f46d 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -21,7 +21,6 @@ import pandas.io.parsers import pandas.io.parsers.readers from modin.core.storage_formats import BaseQueryCompiler # type: ignore -from numpy import dtype from pandas import Timedelta from pandas._libs import lib from pandas._libs.lib import no_default @@ -80,6 +79,7 @@ from snowflake.snowpark.functions import ( abs as abs_, array_construct, + bround, builtin, cast, coalesce, @@ -279,6 +279,7 @@ col_to_timedelta, generate_timestamp_col, raise_if_to_datetime_not_supported, + timedelta_freq_to_nanos, to_snowflake_timestamp_format, ) from snowflake.snowpark.modin.plugin._internal.transpose_utils import ( @@ -504,7 +505,7 @@ def dtypes(self) -> native_pd.Series: ) @property - def index_dtypes(self) -> list[Union[dtype, ExtensionDtype]]: + def index_dtypes(self) -> list[Union[np.dtype, ExtensionDtype]]: """ Get index dtypes. 
@@ -9126,7 +9127,7 @@ def invert(self) -> "SnowflakeQueryCompiler": def astype( self, - col_dtypes_map: dict[str, Union[dtype, ExtensionDtype]], + col_dtypes_map: dict[str, Union[np.dtype, ExtensionDtype]], errors: Literal["raise", "ignore"] = "raise", ) -> "SnowflakeQueryCompiler": """ @@ -9201,7 +9202,7 @@ def astype( def astype_index( self, - col_dtypes_map: dict[Hashable, Union[dtype, ExtensionDtype]], + col_dtypes_map: dict[Hashable, Union[np.dtype, ExtensionDtype]], ) -> "SnowflakeQueryCompiler": """ Convert index columns dtypes to given dtypes. @@ -16511,23 +16512,39 @@ def dt_ceil( if nonexistent != "raise": ErrorMessage.parameter_not_implemented_error("nonexistent", method_name) - slice_length, slice_unit = rule_to_snowflake_width_and_slice_unit( - rule=freq # type: ignore[arg-type] - ) + dtype = self.index_dtypes[0] if include_index else self.dtypes[0] + if is_datetime64_any_dtype(dtype): + slice_length, slice_unit = rule_to_snowflake_width_and_slice_unit( + rule=freq # type: ignore[arg-type] + ) - if slice_unit not in SUPPORTED_DT_FLOOR_CEIL_FREQS: - ErrorMessage.parameter_not_implemented_error(f"freq='{freq}'", method_name) + if slice_unit not in SUPPORTED_DT_FLOOR_CEIL_FREQS: + ErrorMessage.parameter_not_implemented_error( + f"freq='{freq}'", method_name + ) + return_type = None - def ceil_func(column: SnowparkColumn) -> SnowparkColumn: - floor_column = builtin("time_slice")( - column, slice_length, slice_unit, "START" - ) - ceil_column = builtin("time_slice")(column, slice_length, slice_unit, "END") - return iff(column.equal_null(floor_column), column, ceil_column) + def ceil_func(column: SnowparkColumn) -> SnowparkColumn: + floor_column = builtin("time_slice")( + column, slice_length, slice_unit, "START" + ) + ceil_column = builtin("time_slice")( + column, slice_length, slice_unit, "END" + ) + return iff(column.equal_null(floor_column), column, ceil_column) + + else: # timedelta type + nanos = timedelta_freq_to_nanos(freq) + return_type = TimedeltaType() + + def ceil_func(column: SnowparkColumn) -> SnowparkColumn: + return iff( + column % nanos == 0, column, column + nanos - (column % nanos) + ) return SnowflakeQueryCompiler( self._modin_frame.apply_snowpark_function_to_columns( - ceil_func, include_index + ceil_func, include_index, return_type ) ) @@ -16565,94 +16582,115 @@ def dt_round( if nonexistent != "raise": ErrorMessage.parameter_not_implemented_error("nonexistent", method_name) - slice_length, slice_unit = rule_to_snowflake_width_and_slice_unit( - rule=freq # type: ignore[arg-type] - ) - - if slice_unit not in SUPPORTED_DT_FLOOR_CEIL_FREQS or slice_unit == "second": - ErrorMessage.parameter_not_implemented_error(f"freq={freq}", method_name) - - # We need to implement the algorithm for rounding half to even whenever - # the date value is at half point of the slice: - # https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even - - # First, we need to calculate the length of half a slice. - # This is straightforward if the length is already even. - # If not, we then need to first downlevel the freq to a - # lower granularity to ensure that it is even. 
+        dtype = self.index_dtypes[0] if include_index else self.dtypes[0]
+        if is_datetime64_any_dtype(dtype):
+            slice_length, slice_unit = rule_to_snowflake_width_and_slice_unit(
+                rule=freq  # type: ignore[arg-type]
+            )

-        def down_level_freq(slice_length: int, slice_unit: str) -> tuple[int, str]:
-            if slice_unit == "minute":
-                slice_length *= 60
-                slice_unit = "second"
-            elif slice_unit == "hour":
-                slice_length *= 60
-                slice_unit = "minute"
-            elif slice_unit == "day":
-                slice_length *= 24
-                slice_unit = "hour"
-            else:
-                f"Snowpark pandas 'Series.dt.round' method doesn't support setting 'freq' parameter with '{slice_unit}' unit"
-            return slice_length, slice_unit
+            if (
+                slice_unit not in SUPPORTED_DT_FLOOR_CEIL_FREQS
+                or slice_unit == "second"
+            ):
+                ErrorMessage.parameter_not_implemented_error(
+                    f"freq={freq}", method_name
+                )

-        if slice_length % 2 == 1:
-            slice_length, slice_unit = down_level_freq(slice_length, slice_unit)
-        half_slice_length = int(slice_length / 2)
+            # We need to implement the algorithm for rounding half to even whenever
+            # the date value is at the halfway point of the slice:
+            # https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even
+
+            # First, we need to calculate the length of half a slice.
+            # This is straightforward if the length is already even.
+            # If not, we then need to first downlevel the freq to a
+            # lower granularity to ensure that it is even.
+
+            def down_level_freq(slice_length: int, slice_unit: str) -> tuple[int, str]:
+                if slice_unit == "minute":
+                    slice_length *= 60
+                    slice_unit = "second"
+                elif slice_unit == "hour":
+                    slice_length *= 60
+                    slice_unit = "minute"
+                elif slice_unit == "day":
+                    slice_length *= 24
+                    slice_unit = "hour"
+                else:
+                    raise ValueError(f"Snowpark pandas 'Series.dt.round' method doesn't support setting 'freq' parameter with '{slice_unit}' unit")
+                return slice_length, slice_unit

-        def slice_length_when_unit_is_second(slice_length: int, slice_unit: str) -> int:
-            while slice_unit != "second":
+            if slice_length % 2 == 1:
                 slice_length, slice_unit = down_level_freq(slice_length, slice_unit)
-            return slice_length
+            half_slice_length = int(slice_length / 2)
+            return_type = None
+
+            def slice_length_when_unit_is_second(
+                slice_length: int, slice_unit: str
+            ) -> int:
+                while slice_unit != "second":
+                    slice_length, slice_unit = down_level_freq(slice_length, slice_unit)
+                return slice_length
+
+            def round_func(column: SnowparkColumn) -> SnowparkColumn:
+                # Second, we determine whether floor represents an even number of slices.
+                # To do so, we divide the number of epoch seconds in it by the number
+                # of epoch seconds in one slice. This way, we can get the number of slices.
+
+                floor_column = builtin("time_slice")(
+                    column, slice_length, slice_unit, "START"
+                )
+                ceil_column = builtin("time_slice")(
+                    column, slice_length, slice_unit, "END"
+                )

-        def round_func(column: SnowparkColumn) -> SnowparkColumn:
-            # Second, we determine whether floor represents an even number of slices.
-            # To do so, we must divide the number of epoch seconds in it over the number
-            # of epoch seconds in one slice. This way, we can get the number of slices.
+                floor_epoch_seconds_column = builtin("extract")(
+                    "epoch_second", floor_column
+                )
+                floor_num_slices_column = cast(
+                    floor_epoch_seconds_column
+                    / pandas_lit(
+                        slice_length_when_unit_is_second(slice_length, slice_unit)
+                    ),
+                    IntegerType(),
+                )

-            floor_column = builtin("time_slice")(
-                column, slice_length, slice_unit, "START"
-            )
-            ceil_column = builtin("time_slice")(column, slice_length, slice_unit, "END")
+                # Now that we know the number of slices, we can check if they are even or odd.
+                floor_is_even = (floor_num_slices_column % pandas_lit(2)).equal_null(
+                    pandas_lit(0)
+                )

-            floor_epoch_seconds_column = builtin("extract")(
-                "epoch_second", floor_column
-            )
-            floor_num_slices_column = cast(
-                floor_epoch_seconds_column
-                / pandas_lit(
-                    slice_length_when_unit_is_second(slice_length, slice_unit)
-                ),
-                IntegerType(),
-            )
+                # Accordingly, we can decide if the round column should be the floor or ceil
+                # of the slice.
+                round_column_if_half_point = iff(
+                    floor_is_even, floor_column, ceil_column
+                )

-            # Now that we know the number of slices, we can check if they are even or odd.
-            floor_is_even = (floor_num_slices_column % pandas_lit(2)).equal_null(
-                pandas_lit(0)
-            )
+                # If the date value is not at the halfway point of the slice, we shift it
+                # by half a slice and take the floor from there.
+                base_plus_half_slice_column = dateadd(
+                    slice_unit, pandas_lit(half_slice_length), column
+                )
+                round_column_if_not_half_point = builtin("time_slice")(
+                    base_plus_half_slice_column, slice_length, slice_unit, "START"
+                )

-            # Accordingly, we can decide if the round column should be the floor or ceil
-            # of the slice.
-            round_column_if_half_point = iff(floor_is_even, floor_column, ceil_column)
+                # The final expression for the round column.
+                return iff(
+                    base_plus_half_slice_column.equal_null(ceil_column),
+                    round_column_if_half_point,
+                    round_column_if_not_half_point,
+                )

-            # In case the date value is not at half point of the slice, then we shift it
-            # by half a slice, and take the floor from there.
-            base_plus_half_slice_column = dateadd(
-                slice_unit, pandas_lit(half_slice_length), column
-            )
-            round_column_if_not_half_point = builtin("time_slice")(
-                base_plus_half_slice_column, slice_length, slice_unit, "START"
-            )
+        else:  # timedelta type
+            nanos = timedelta_freq_to_nanos(freq)
+            return_type = TimedeltaType()

-            # The final expression for the round column.
-            return iff(
-                base_plus_half_slice_column.equal_null(ceil_column),
-                round_column_if_half_point,
-                round_column_if_not_half_point,
-            )
+            def round_func(column: SnowparkColumn) -> SnowparkColumn:
+                return bround(column / nanos, 0) * nanos

         return SnowflakeQueryCompiler(
             self._modin_frame.apply_snowpark_function_to_columns(
-                round_func, include_index
+                round_func, include_index, return_type
             )
         )

@@ -16683,26 +16721,39 @@ def dt_floor(
         Returns:
            A new QueryCompiler with floor values.
        """
+        # This method supports both datetime and timedelta types.
method_name = "DatetimeIndex.floor" if include_index else "Series.dt.floor" if ambiguous != "raise": ErrorMessage.parameter_not_implemented_error("ambiguous", method_name) if nonexistent != "raise": ErrorMessage.parameter_not_implemented_error("nonexistent", method_name) - slice_length, slice_unit = rule_to_snowflake_width_and_slice_unit( - rule=freq # type: ignore[arg-type] - ) + dtype = self.index_dtypes[0] if include_index else self.dtypes[0] + if is_datetime64_any_dtype(dtype): + slice_length, slice_unit = rule_to_snowflake_width_and_slice_unit( + rule=freq # type: ignore[arg-type] + ) - if slice_unit not in SUPPORTED_DT_FLOOR_CEIL_FREQS: - ErrorMessage.parameter_not_implemented_error(f"freq='{freq}'", method_name) + if slice_unit not in SUPPORTED_DT_FLOOR_CEIL_FREQS: + ErrorMessage.parameter_not_implemented_error( + f"freq='{freq}'", method_name + ) + return_type = None + + def floor_func(column: SnowparkColumn) -> SnowparkColumn: + return builtin("time_slice")(column, slice_length, slice_unit) + + else: # timedelta type + nanos = timedelta_freq_to_nanos(freq) + return_type = TimedeltaType() - def floor_func(column: SnowparkColumn) -> SnowparkColumn: - return builtin("time_slice")(column, slice_length, slice_unit) + def floor_func(column: SnowparkColumn) -> SnowparkColumn: + return column - (column % nanos) return SnowflakeQueryCompiler( self._modin_frame.apply_snowpark_function_to_columns( - floor_func, include_index - ) + floor_func, include_index, return_type + ), ) def dt_normalize(self, include_index: bool = False) -> "SnowflakeQueryCompiler": diff --git a/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py b/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py index 5e8dc016d43..4488e09bb36 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py @@ -303,7 +303,6 @@ def inferred_freq(self) -> str | None: '10D' """ - @timedelta_index_not_implemented() def round(self, freq: Frequency) -> TimedeltaIndex: """ Perform round operation on the data to the specified `freq`. @@ -323,8 +322,10 @@ def round(self, freq: Frequency) -> TimedeltaIndex: ------ ValueError if the `freq` cannot be converted. """ + return TimedeltaIndex( + query_compiler=self._query_compiler.dt_round(freq, include_index=True) + ) - @timedelta_index_not_implemented() def floor(self, freq: Frequency) -> TimedeltaIndex: """ Perform floor operation on the data to the specified `freq`. @@ -344,8 +345,10 @@ def floor(self, freq: Frequency) -> TimedeltaIndex: ------ ValueError if the `freq` cannot be converted. """ + return TimedeltaIndex( + query_compiler=self._query_compiler.dt_floor(freq, include_index=True) + ) - @timedelta_index_not_implemented() def ceil(self, freq: Frequency) -> TimedeltaIndex: """ Perform ceil operation on the data to the specified `freq`. @@ -365,6 +368,9 @@ def ceil(self, freq: Frequency) -> TimedeltaIndex: ------ ValueError if the `freq` cannot be converted. 
""" + return TimedeltaIndex( + query_compiler=self._query_compiler.dt_ceil(freq, include_index=True) + ) @timedelta_index_not_implemented() def to_pytimedelta(self) -> np.ndarray: diff --git a/tests/integ/modin/index/test_timedelta_index_methods.py b/tests/integ/modin/index/test_timedelta_index_methods.py index 449ad576ef7..d4f82f62323 100644 --- a/tests/integ/modin/index/test_timedelta_index_methods.py +++ b/tests/integ/modin/index/test_timedelta_index_methods.py @@ -8,7 +8,7 @@ import snowflake.snowpark.modin.plugin # noqa: F401 from tests.integ.modin.sql_counter import sql_count_checker -from tests.integ.modin.utils import assert_index_equal +from tests.integ.modin.utils import assert_index_equal, eval_snowpark_pandas_result @sql_count_checker(query_count=0) @@ -74,3 +74,51 @@ def test_timedelta_index_properties(attr): assert_index_equal( getattr(snow_index, attr), getattr(native_index, attr), exact=False ) + + +@pytest.mark.parametrize("method", ["round", "floor", "ceil"]) +@pytest.mark.parametrize("freq", ["ns", "us", "ms", "s", "min", "h", "d"]) +@sql_count_checker(query_count=1) +def test_timedelta_floor_ceil_round(method, freq): + native_index = native_pd.TimedeltaIndex( + [ + "1d", + "1h", + "5h", + "9h", + "60s", + "1s", + "800ms", + "900ms", + "5us", + "6ns", + "1ns", + "1d 3s", + "9m 15s 8us", + None, + ] + ) + snow_index = pd.Index(native_index) + eval_snowpark_pandas_result( + snow_index, native_index, lambda x: getattr(x, method)(freq) + ) + + +@pytest.mark.parametrize("method", ["round", "floor", "ceil"]) +@pytest.mark.parametrize( + "freq", ["nano", "millis", "second", "minutes", "hour", "days", "month", "year"] +) +@sql_count_checker(query_count=0) +def test_timedelta_floor_ceil_round_negative(method, freq): + native_index = native_pd.TimedeltaIndex( + ["1d", "5h", "60s", "1s", "900ms", "5us", "1ns", "1d 3s", "9m 15s 8us", None] + ) + snow_index = pd.Index(native_index) + eval_snowpark_pandas_result( + snow_index, + native_index, + lambda x: getattr(x, method)(freq), + expect_exception=True, + expect_exception_type=ValueError, + expect_exception_match=f"Invalid frequency: {freq}", + )