Skip to content

Commit

Permalink
REF (string): move ArrowStringArrayNumpySemantics methods to base cla…
Browse files Browse the repository at this point in the history
…ss (pandas-dev#59501)

* REF: move ArrowStringArrayNumpySemantics methods to parent class

* REF: move methods to ArrowStringArray

* mypy fixup

* Fix incorrect double-unpacking

* move methods to subclass
  • Loading branch information
jbrockmendel authored and WillAyd committed Aug 22, 2024
1 parent 122a56e commit 14d6804
Showing 1 changed file with 48 additions and 61 deletions.
109 changes: 48 additions & 61 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

from functools import partial
import operator
import re
from typing import (
Expand Down Expand Up @@ -209,12 +208,17 @@ def dtype(self) -> StringDtype: # type: ignore[override]
return self._dtype

def insert(self, loc: int, item) -> ArrowStringArray:
if self.dtype.na_value is np.nan and item is np.nan:
item = libmissing.NA
if not isinstance(item, str) and item is not libmissing.NA:
raise TypeError("Scalar must be NA or str")
return super().insert(loc, item)

@classmethod
def _result_converter(cls, values, na=None):
def _result_converter(self, values, na=None):
if self.dtype.na_value is np.nan:
if not isna(na):
values = values.fill_null(bool(na))
return ArrowExtensionArray(values).to_numpy(na_value=np.nan)
return BooleanDtype().__from_arrow__(values)

def _maybe_convert_setitem_value(self, value):
Expand Down Expand Up @@ -494,11 +498,30 @@ def _str_get_dummies(self, sep: str = "|"):
return dummies.astype(np.int64, copy=False), labels

def _convert_int_dtype(self, result):
if self.dtype.na_value is np.nan:
if isinstance(result, pa.Array):
result = result.to_numpy(zero_copy_only=False)
else:
result = result.to_numpy()
if result.dtype == np.int32:
result = result.astype(np.int64)
return result

return Int64Dtype().__from_arrow__(result)

def _reduce(
self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
):
if self.dtype.na_value is np.nan and name in ["any", "all"]:
if not skipna:
nas = pc.is_null(self._pa_array)
arr = pc.or_kleene(nas, pc.not_equal(self._pa_array, ""))
else:
arr = pc.not_equal(self._pa_array, "")
return ArrowExtensionArray(arr)._reduce(
name, skipna=skipna, keepdims=keepdims, **kwargs
)

result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs)
if name in ("argmin", "argmax") and isinstance(result, pa.Array):
return self._convert_int_dtype(result)
Expand Down Expand Up @@ -529,67 +552,31 @@ def _rank(
)
)


class ArrowStringArrayNumpySemantics(ArrowStringArray):
_storage = "pyarrow"
_na_value = np.nan

@classmethod
def _result_converter(cls, values, na=None):
if not isna(na):
values = values.fill_null(bool(na))
return ArrowExtensionArray(values).to_numpy(na_value=np.nan)

def __getattribute__(self, item):
# ArrowStringArray and we both inherit from ArrowExtensionArray, which
# creates inheritance problems (Diamond inheritance)
if item in ArrowStringArrayMixin.__dict__ and item not in (
"_pa_array",
"__dict__",
):
return partial(getattr(ArrowStringArrayMixin, item), self)
return super().__getattribute__(item)

def _convert_int_dtype(self, result):
if isinstance(result, pa.Array):
result = result.to_numpy(zero_copy_only=False)
else:
result = result.to_numpy()
if result.dtype == np.int32:
result = result.astype(np.int64)
def value_counts(self, dropna: bool = True) -> Series:
result = super().value_counts(dropna=dropna)
if self.dtype.na_value is np.nan:
res_values = result._values.to_numpy()
return result._constructor(
res_values, index=result.index, name=result.name, copy=False
)
return result

def _cmp_method(self, other, op):
result = super()._cmp_method(other, op)
if op == operator.ne:
return result.to_numpy(np.bool_, na_value=True)
else:
return result.to_numpy(np.bool_, na_value=False)

def value_counts(self, dropna: bool = True) -> Series:
from pandas import Series

result = super().value_counts(dropna)
return Series(
result._values.to_numpy(), index=result.index, name=result.name, copy=False
)

def _reduce(
self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
):
if name in ["any", "all"]:
if not skipna and name == "all":
nas = pc.invert(pc.is_null(self._pa_array))
arr = pc.and_kleene(nas, pc.not_equal(self._pa_array, ""))
if self.dtype.na_value is np.nan:
if op == operator.ne:
return result.to_numpy(np.bool_, na_value=True)
else:
arr = pc.not_equal(self._pa_array, "")
return ArrowExtensionArray(arr)._reduce(
name, skipna=skipna, keepdims=keepdims, **kwargs
)
else:
return super()._reduce(name, skipna=skipna, keepdims=keepdims, **kwargs)
return result.to_numpy(np.bool_, na_value=False)
return result

def insert(self, loc: int, item) -> ArrowStringArrayNumpySemantics:
if item is np.nan:
item = libmissing.NA
return super().insert(loc, item) # type: ignore[return-value]

class ArrowStringArrayNumpySemantics(ArrowStringArray):
_na_value = np.nan
_str_get = ArrowStringArrayMixin._str_get
_str_removesuffix = ArrowStringArrayMixin._str_removesuffix
_str_capitalize = ArrowStringArrayMixin._str_capitalize
_str_pad = ArrowStringArrayMixin._str_pad
_str_title = ArrowStringArrayMixin._str_title
_str_swapcase = ArrowStringArrayMixin._str_swapcase
_str_slice_replace = ArrowStringArrayMixin._str_slice_replace

0 comments on commit 14d6804

Please sign in to comment.