From 5a37b2786d3182cb3a3a43c8e2360cf91e5c9957 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Tue, 5 Dec 2023 15:31:58 -0800 Subject: [PATCH 1/5] ENH: Implement Series.interpolate for ArrowDtype --- doc/source/whatsnew/v2.2.0.rst | 1 + pandas/core/arrays/arrow/array.py | 37 +++++++++++++++++++++++++++- pandas/tests/extension/test_arrow.py | 35 ++++++++++++++++++++++++++ 3 files changed, 72 insertions(+), 1 deletion(-) diff --git a/doc/source/whatsnew/v2.2.0.rst b/doc/source/whatsnew/v2.2.0.rst index b092cf6f81ef2..ae497e7158aa2 100644 --- a/doc/source/whatsnew/v2.2.0.rst +++ b/doc/source/whatsnew/v2.2.0.rst @@ -227,6 +227,7 @@ Other enhancements - Allow passing ``read_only``, ``data_only`` and ``keep_links`` arguments to openpyxl using ``engine_kwargs`` of :func:`read_excel` (:issue:`55027`) - DataFrame.apply now allows the usage of numba (via ``engine="numba"``) to JIT compile the passed function, allowing for potential speedups (:issue:`54666`) - Implement masked algorithms for :meth:`Series.value_counts` (:issue:`54984`) +- Implemented :meth:`Series.interpolate` for :class:`ArrowDtype` (:issue:`56267`) - Improved error message that appears in :meth:`DatetimeIndex.to_period` with frequencies which are not supported as period frequencies, such as "BMS" (:issue:`56243`) - Improved error message when constructing :class:`Period` with invalid offsets such as "QS" (:issue:`55785`) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index d162b66e5d369..f13bb45fc7144 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -150,6 +150,7 @@ def floordiv_compat( AxisInt, Dtype, FillnaOptions, + InterpolateOptions, Iterator, NpDtype, NumpySorter, @@ -164,7 +165,10 @@ def floordiv_compat( npt, ) - from pandas import Series + from pandas import ( + Index, + Series, + ) from pandas.core.arrays.datetimes import DatetimeArray from pandas.core.arrays.timedeltas import TimedeltaArray @@ -1831,6 +1835,37 @@ def _rank_calc( return result + def interpolate( + self, + *, + method: InterpolateOptions, + axis: int, + index: Index, + limit, + limit_direction, + limit_area, + copy: bool, + **kwargs, + ) -> Self: + if method != "linear": + raise NotImplementedError("Only method='linear' is implemented.") + if limit_area is not None: + raise NotImplementedError("Only limit_area=None is implemented.") + if limit is not None: + raise NotImplementedError("Only limit=0 is implemented.") + if limit_direction != "forward": + raise NotImplementedError("Only limit_direction='forward' is implemented.") + + if not self.dtype._is_numeric: + raise ValueError("Values must be numeric.") + + values = self._pa_array.combine_chunks() + na_value = pa.array([None], type=values.type) + y_diff_2 = pc.fill_null_backward(pc.pairwise_diff_checked(values, period=2)) + prev_values = pa.concat_arrays([na_value, values[:-2], na_value]) + interps = pc.add_checked(prev_values, pc.divide_checked(y_diff_2, 2)) + return type(self)(pc.coalesce(self._pa_array, interps)) + def _rank( self, *, diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 7131a50956a7d..a119045636867 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -2959,3 +2959,38 @@ def test_arrow_floordiv(): expected = pd.Series([-2], dtype="int64[pyarrow]") result = a // b tm.assert_series_equal(result, expected) + + +def test_interpolate_not_numeric(data): + if not data.dtype._is_numeric: + with pytest.raises(ValueError, match="Values must be numeric."): + pd.Series(data).interpolate() + + +def test_interpolate_not_supported(): + ser = pd.Series([1, None], dtype="int64[pyarrow]") + with pytest.raises( + NotImplementedError, match="Only method='linear' is implemented." + ): + ser.interpolate(method="akima") + + with pytest.raises( + NotImplementedError, match="Only limit_area=None is implemented." + ): + ser.interpolate(limit_area="inside") + + with pytest.raises(NotImplementedError, match="Only limit=0 is implemented."): + ser.interpolate(limit=1) + + with pytest.raises( + NotImplementedError, match="Only limit_direction='forward' is implemented." + ): + ser.interpolate(limit_direction="backward") + + +@pytest.mark.parametrize("dtype", ["int64[pyarrow]", "float64[pyarrow]"]) +def test_interpolate(dtype): + ser = pd.Series([None, 1, 2, None, 4, None], dtype=dtype) + result = ser.interpolate() + expected = pd.Series([None, 1, 2, 3, 4, None], dtype=dtype) + tm.assert_series_equal(result, expected) From 44c4c9ba24e419773c070bbe7dbce73facd9f414 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Tue, 5 Dec 2023 16:45:51 -0800 Subject: [PATCH 2/5] Min version compat --- pandas/core/arrays/arrow/array.py | 3 +++ pandas/tests/extension/test_arrow.py | 9 +++++++++ 2 files changed, 12 insertions(+) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index e95ee35d23839..4f14af9b24200 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1848,6 +1848,9 @@ def interpolate( copy: bool, **kwargs, ) -> Self: + if pa_version_under13p0: + raise NotImplementedError("interpolate requires pyarrow version > 12") + if method != "linear": raise NotImplementedError("Only method='linear' is implemented.") if limit_area is not None: diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 5497d9d09e445..026ad8f2a7ae3 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -3006,12 +3006,18 @@ def test_string_to_datetime_parsing_cast(): tm.assert_series_equal(result, expected) +@pytest.mark.skipif( + pa_version_under13p0, reason="pairwise_diff_checked not implemented in pyarrow" +) def test_interpolate_not_numeric(data): if not data.dtype._is_numeric: with pytest.raises(ValueError, match="Values must be numeric."): pd.Series(data).interpolate() +@pytest.mark.skipif( + pa_version_under13p0, reason="pairwise_diff_checked not implemented in pyarrow" +) def test_interpolate_not_supported(): ser = pd.Series([1, None], dtype="int64[pyarrow]") with pytest.raises( @@ -3033,6 +3039,9 @@ def test_interpolate_not_supported(): ser.interpolate(limit_direction="backward") +@pytest.mark.skipif( + pa_version_under13p0, reason="pairwise_diff_checked not implemented in pyarrow" +) @pytest.mark.parametrize("dtype", ["int64[pyarrow]", "float64[pyarrow]"]) def test_interpolate(dtype): ser = pd.Series([None, 1, 2, None, 4, None], dtype=dtype) From bab1c999d016fbf02bd387910149c89084ac7d7d Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Tue, 23 Jan 2024 10:14:48 -0800 Subject: [PATCH 3/5] Fold into interpolate --- pandas/core/arrays/arrow/array.py | 56 ++++++++++--------------------- 1 file changed, 18 insertions(+), 38 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 0c4a8d1407597..c4207082c1245 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -199,10 +199,7 @@ def floordiv_compat( npt, ) - from pandas import ( - Index, - Series, - ) + from pandas import Series from pandas.core.arrays.datetimes import DatetimeArray from pandas.core.arrays.timedeltas import TimedeltaArray @@ -1956,40 +1953,6 @@ def _rank_calc( return result - def interpolate( - self, - *, - method: InterpolateOptions, - axis: int, - index: Index, - limit, - limit_direction, - limit_area, - copy: bool, - **kwargs, - ) -> Self: - if pa_version_under13p0: - raise NotImplementedError("interpolate requires pyarrow version > 12") - - if method != "linear": - raise NotImplementedError("Only method='linear' is implemented.") - if limit_area is not None: - raise NotImplementedError("Only limit_area=None is implemented.") - if limit is not None: - raise NotImplementedError("Only limit=0 is implemented.") - if limit_direction != "forward": - raise NotImplementedError("Only limit_direction='forward' is implemented.") - - if not self.dtype._is_numeric: - raise ValueError("Values must be numeric.") - - values = self._pa_array.combine_chunks() - na_value = pa.array([None], type=values.type) - y_diff_2 = pc.fill_null_backward(pc.pairwise_diff_checked(values, period=2)) - prev_values = pa.concat_arrays([na_value, values[:-2], na_value]) - interps = pc.add_checked(prev_values, pc.divide_checked(y_diff_2, 2)) - return type(self)(pc.coalesce(self._pa_array, interps)) - def _rank( self, *, @@ -2118,6 +2081,23 @@ def interpolate( See NDFrame.interpolate.__doc__. """ # NB: we return type(self) even if copy=False + if not self.dtype._is_numeric: + raise ValueError("Values must be numeric.") + + if ( + not pa_version_under13p0 + and method == "linear" + and limit_area is None + and limit is None + and limit_direction == "forward" + ): + values = self._pa_array.combine_chunks() + na_value = pa.array([None], type=values.type) + y_diff_2 = pc.fill_null_backward(pc.pairwise_diff_checked(values, period=2)) + prev_values = pa.concat_arrays([na_value, values[:-2], na_value]) + interps = pc.add_checked(prev_values, pc.divide_checked(y_diff_2, 2)) + return type(self)(pc.coalesce(self._pa_array, interps)) + mask = self.isna() if self.dtype.kind == "f": data = self._pa_array.to_numpy() From c1f69a21104864ed9ed83b7a2367bc04ca42bb14 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Tue, 23 Jan 2024 10:15:39 -0800 Subject: [PATCH 4/5] Remove from 2.2 --- doc/source/whatsnew/v2.2.0.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/doc/source/whatsnew/v2.2.0.rst b/doc/source/whatsnew/v2.2.0.rst index 58b46061702b0..d9ab0452c8334 100644 --- a/doc/source/whatsnew/v2.2.0.rst +++ b/doc/source/whatsnew/v2.2.0.rst @@ -346,7 +346,6 @@ Other enhancements - Implement :meth:`Series.interpolate` and :meth:`DataFrame.interpolate` for :class:`ArrowDtype` and masked dtypes (:issue:`56267`) - Implement masked algorithms for :meth:`Series.value_counts` (:issue:`54984`) - Implemented :meth:`Series.dt` methods and attributes for :class:`ArrowDtype` with ``pyarrow.duration`` type (:issue:`52284`) -- Implemented :meth:`Series.interpolate` for :class:`ArrowDtype` (:issue:`56267`) - Implemented :meth:`Series.str.extract` for :class:`ArrowDtype` (:issue:`56268`) - Improved error message that appears in :meth:`DatetimeIndex.to_period` with frequencies which are not supported as period frequencies, such as ``"BMS"`` (:issue:`56243`) - Improved error message when constructing :class:`Period` with invalid offsets such as ``"QS"`` (:issue:`55785`) From 83ee82fa2d0e7687bc9f91a6a605bb07b94c460e Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Tue, 23 Jan 2024 10:25:22 -0800 Subject: [PATCH 5/5] Modify tests --- pandas/tests/extension/test_arrow.py | 26 +------------------------- 1 file changed, 1 insertion(+), 25 deletions(-) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index d4d5b1fd68480..c6f850b03ca24 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -3379,35 +3379,11 @@ def test_interpolate_not_numeric(data): pd.Series(data).interpolate() -@pytest.mark.skipif( - pa_version_under13p0, reason="pairwise_diff_checked not implemented in pyarrow" -) -def test_interpolate_not_supported(): - ser = pd.Series([1, None], dtype="int64[pyarrow]") - with pytest.raises( - NotImplementedError, match="Only method='linear' is implemented." - ): - ser.interpolate(method="akima") - - with pytest.raises( - NotImplementedError, match="Only limit_area=None is implemented." - ): - ser.interpolate(limit_area="inside") - - with pytest.raises(NotImplementedError, match="Only limit=0 is implemented."): - ser.interpolate(limit=1) - - with pytest.raises( - NotImplementedError, match="Only limit_direction='forward' is implemented." - ): - ser.interpolate(limit_direction="backward") - - @pytest.mark.skipif( pa_version_under13p0, reason="pairwise_diff_checked not implemented in pyarrow" ) @pytest.mark.parametrize("dtype", ["int64[pyarrow]", "float64[pyarrow]"]) -def test_interpolate(dtype): +def test_interpolate_linear(dtype): ser = pd.Series([None, 1, 2, None, 4, None], dtype=dtype) result = ser.interpolate() expected = pd.Series([None, 1, 2, 3, 4, None], dtype=dtype)