Skip to content

Commit

Permalink
ENH: Implement Series.interpolate for ArrowDtype (#56347)
Browse files Browse the repository at this point in the history
* ENH: Implement Series.interpolate for ArrowDtype

* Min version compat

* Fold into interpolate

* Remove from 2.2

* Modify tests
  • Loading branch information
mroeschke authored Feb 8, 2024
1 parent 4d3964e commit 2110b74
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
17 changes: 17 additions & 0 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2086,6 +2086,23 @@ def interpolate(
See NDFrame.interpolate.__doc__.
"""
# NB: we return type(self) even if copy=False
if not self.dtype._is_numeric:
raise ValueError("Values must be numeric.")

if (
not pa_version_under13p0
and method == "linear"
and limit_area is None
and limit is None
and limit_direction == "forward"
):
values = self._pa_array.combine_chunks()
na_value = pa.array([None], type=values.type)
y_diff_2 = pc.fill_null_backward(pc.pairwise_diff_checked(values, period=2))
prev_values = pa.concat_arrays([na_value, values[:-2], na_value])
interps = pc.add_checked(prev_values, pc.divide_checked(y_diff_2, 2))
return type(self)(pc.coalesce(self._pa_array, interps))

mask = self.isna()
if self.dtype.kind == "f":
data = self._pa_array.to_numpy()
Expand Down
20 changes: 20 additions & 0 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3436,6 +3436,26 @@ def test_string_to_datetime_parsing_cast():
tm.assert_series_equal(result, expected)


@pytest.mark.skipif(
    pa_version_under13p0, reason="pairwise_diff_checked not implemented in pyarrow"
)
def test_interpolate_not_numeric(data):
    """Check that interpolating a non-numeric Arrow-backed Series raises.

    ``data`` is presumably the parametrized extension-array fixture from the
    test suite (TODO confirm); only its ``dtype._is_numeric`` flag is read here.
    """
    # Numeric inputs are out of scope for this check, so there is nothing to assert.
    if data.dtype._is_numeric:
        return

    with pytest.raises(ValueError, match="Values must be numeric."):
        pd.Series(data).interpolate()


@pytest.mark.skipif(
    pa_version_under13p0, reason="pairwise_diff_checked not implemented in pyarrow"
)
@pytest.mark.parametrize("dtype", ["int64[pyarrow]", "float64[pyarrow]"])
def test_interpolate_linear(dtype):
    """Check default ``Series.interpolate()`` on Arrow int64 and float64 data.

    The single interior null (between 2 and 4) is expected to become 3, while
    the leading and trailing nulls are expected to remain null.
    """
    raw_values = [None, 1, 2, None, 4, None]
    filled_values = [None, 1, 2, 3, 4, None]

    tm.assert_series_equal(
        pd.Series(raw_values, dtype=dtype).interpolate(),
        pd.Series(filled_values, dtype=dtype),
    )


def test_string_to_time_parsing_cast():
# GH 56463
string_times = ["11:41:43.076160"]
Expand Down

0 comments on commit 2110b74

Please sign in to comment.