Skip to content

Commit

Permalink
BUG: pickling subset of Arrow-backed data would serialize the entire …
Browse files Browse the repository at this point in the history
…data (pandas-dev#49078)

* BUG: pickling subset of Arrow-backed data would serialize the entire data

* Use data
  • Loading branch information
mroeschke authored and noatamir committed Nov 9, 2022
1 parent df65eee commit f015b19
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 0 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ MultiIndex
I/O
^^^
- Bug in :func:`read_sas` caused fragmentation of :class:`DataFrame` and raised :class:`.errors.PerformanceWarning` (:issue:`48595`)
- Bug when pickling a subset of PyArrow-backed data that would serialize the entire data instead of the subset (:issue:`42600`)
- Bug in :func:`read_csv` for a single-line csv with fewer columns than ``names`` raised :class:`.errors.ParserError` with ``engine="c"`` (:issue:`47566`)
-

Expand Down
11 changes: 11 additions & 0 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,17 @@ def __pos__(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
def __abs__(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
    """Return a new array of the same type holding ``pc.abs_checked`` of the data."""
    # NOTE(review): ``pc`` is presumably ``pyarrow.compute`` -- confirm
    # against the module imports, which are not visible in this hunk.
    absolute = pc.abs_checked(self._data)
    cls = type(self)
    return cls(absolute)

# GH 42600: __getstate__/__setstate__ not necessary once
# https://issues.apache.org/jira/browse/ARROW-10739 is addressed
def __getstate__(self):
    """Build the state to pickle: a copy of the instance dict with ``_data`` combined.

    ``_data`` is replaced by the result of ``combine_chunks()`` so that the
    pickled payload is built from the combined data rather than from the
    original chunked object (GH 42600).  The live instance is left untouched
    because only a shallow copy of ``__dict__`` is modified.
    """
    return {**self.__dict__, "_data": self._data.combine_chunks()}

def __setstate__(self, state) -> None:
    """Restore the instance from pickled ``state``.

    The ``"_data"`` entry is passed through ``pa.chunked_array`` before the
    instance dict is updated.  Note that ``state`` itself is modified in
    place.
    """
    rewrapped = pa.chunked_array(state["_data"])
    state["_data"] = rewrapped
    self.__dict__.update(state)

def _cmp_method(self, other, op):
from pandas.arrays import BooleanArray

Expand Down
18 changes: 18 additions & 0 deletions pandas/tests/arrays/string_/test_string_arrow.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pickle
import re

import numpy as np
Expand Down Expand Up @@ -197,3 +198,20 @@ def test_setitem_invalid_indexer_raises():

with pytest.raises(ValueError, match=None):
arr[[0, 1]] = ["foo", "bar", "baz"]


@skip_if_no_pyarrow
def test_pickle_roundtrip():
    """Pickling a slice of a ``string[pyarrow]`` Series must not store the full data."""
    # GH 42600
    whole = pd.Series(range(10), dtype="string[pyarrow]")
    first_two = whole.head(2)

    whole_bytes = pickle.dumps(whole)
    first_two_bytes = pickle.dumps(first_two)

    # The payload of the 2-row slice must be strictly smaller than that of
    # the 10-row parent; otherwise the slice dragged the whole buffer along.
    assert len(whole_bytes) > len(first_two_bytes)

    # Both payloads must still round-trip to equal Series.
    tm.assert_series_equal(pickle.loads(whole_bytes), whole)
    tm.assert_series_equal(pickle.loads(first_two_bytes), first_two)
17 changes: 17 additions & 0 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
BytesIO,
StringIO,
)
import pickle

import numpy as np
import pytest
Expand Down Expand Up @@ -1347,3 +1348,19 @@ def test_is_bool_dtype():
result = s[data]
expected = s[np.asarray(data)]
tm.assert_series_equal(result, expected)


def test_pickle_roundtrip(data):
    """Pickling a slice of an Arrow-backed Series must not store the full data."""
    # GH 42600
    whole = pd.Series(data)
    first_two = whole.head(2)

    whole_bytes = pickle.dumps(whole)
    first_two_bytes = pickle.dumps(first_two)

    # The sliced payload must be strictly smaller than the full one;
    # otherwise the slice serialized the entire underlying data.
    assert len(whole_bytes) > len(first_two_bytes)

    # Both payloads must still round-trip to equal Series.
    tm.assert_series_equal(pickle.loads(whole_bytes), whole)
    tm.assert_series_equal(pickle.loads(first_two_bytes), first_two)

0 comments on commit f015b19

Please sign in to comment.