CLN: deduplicate __setitem__ and _reduce on masked arrays (#34187)
jorisvandenbossche authored May 20, 2020
1 parent 23c7e85 commit 6ad157e
Showing 3 changed files with 53 additions and 73 deletions.
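
In short, the commit moves the shared __setitem__ and _reduce implementations up to BaseMaskedArray and leaves only a small _coerce_to_array hook on BooleanArray and IntegerArray. A minimal, self-contained sketch of that pattern (illustrative only, with simplified names, not the actual pandas classes):

import numpy as np


class MaskedBase:
    # Shared logic: the base class owns __setitem__ and only asks subclasses
    # how to turn incoming values into a (data, mask) pair.
    def __init__(self, data, mask):
        self._data = np.asarray(data)
        self._mask = np.asarray(mask, dtype=bool)

    def _coerce_to_array(self, value):
        raise NotImplementedError  # subclass-specific coercion

    def __setitem__(self, key, value):
        is_scalar = np.ndim(value) == 0
        if is_scalar:
            value = [value]
        data, mask = self._coerce_to_array(value)
        if is_scalar:
            data, mask = data[0], mask[0]
        self._data[key] = data
        self._mask[key] = mask


class IntMasked(MaskedBase):
    # The subclass only defines the coercion; everything else is inherited.
    def _coerce_to_array(self, value):
        mask = np.array([v is None for v in value], dtype=bool)
        data = np.array([0 if m else v for v, m in zip(value, mask)], dtype="int64")
        return data, mask


arr = IntMasked([1, 2, 3], [False, False, False])
arr[0] = None        # scalar assignment: the mask entry is set via the shared path
arr[1:] = [10, 20]   # listlike assignment is coerced by the subclass hook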
39 changes: 4 additions & 35 deletions pandas/core/arrays/boolean.py
@@ -18,16 +18,13 @@
     is_integer_dtype,
     is_list_like,
     is_numeric_dtype,
-    is_scalar,
     pandas_dtype,
 )
 from pandas.core.dtypes.dtypes import register_extension_dtype
 from pandas.core.dtypes.generic import ABCDataFrame, ABCIndexClass, ABCSeries
 from pandas.core.dtypes.missing import isna
 
-from pandas.core import nanops, ops
-from pandas.core.array_algos import masked_reductions
-from pandas.core.indexers import check_array_indexer
+from pandas.core import ops
 
 from .masked import BaseMaskedArray, BaseMaskedDtype
 
@@ -347,19 +344,8 @@ def reconstruct(x):
         else:
             return reconstruct(result)
 
-    def __setitem__(self, key, value) -> None:
-        _is_scalar = is_scalar(value)
-        if _is_scalar:
-            value = [value]
-        value, mask = coerce_to_array(value)
-
-        if _is_scalar:
-            value = value[0]
-            mask = mask[0]
-
-        key = check_array_indexer(self, key)
-        self._data[key] = value
-        self._mask[key] = mask
+    def _coerce_to_array(self, value) -> Tuple[np.ndarray, np.ndarray]:
+        return coerce_to_array(value)
 
     def astype(self, dtype, copy: bool = True) -> ArrayLike:
         """
@@ -670,24 +656,7 @@ def _reduce(self, name: str, skipna: bool = True, **kwargs):
         if name in {"any", "all"}:
             return getattr(self, name)(skipna=skipna, **kwargs)
 
-        data = self._data
-        mask = self._mask
-
-        if name in {"sum", "prod", "min", "max"}:
-            op = getattr(masked_reductions, name)
-            return op(data, mask, skipna=skipna, **kwargs)
-
-        # coerce to a nan-aware float if needed
-        if self._hasna:
-            data = self.to_numpy("float64", na_value=np.nan)
-
-        op = getattr(nanops, "nan" + name)
-        result = op(data, axis=0, skipna=skipna, mask=mask, **kwargs)
-
-        if np.isnan(result):
-            return libmissing.NA
-
-        return result
+        return super()._reduce(name, skipna, **kwargs)
 
     def _maybe_mask_result(self, result, mask, other, op_name: str):
         """
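With the delegation above, BooleanArray._reduce only special-cases any/all and defers everything else to the shared BaseMaskedArray._reduce, so behaviour should be unchanged. For example (illustrative, using the Series API over the nullable boolean dtype):

import pandas as pd

s = pd.Series([True, True, None], dtype="boolean")
print(s.sum())   # 2 -- sum/prod/min/max go through masked_reductions
print(s.any())   # True -- any/all are still handled by BooleanArray itself
print(s.mean())  # 1.0 -- other reductions take the nan-aware nanops fallback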
40 changes: 3 additions & 37 deletions pandas/core/arrays/integer.py
@@ -20,15 +20,13 @@
     is_integer_dtype,
     is_list_like,
     is_object_dtype,
-    is_scalar,
     pandas_dtype,
 )
 from pandas.core.dtypes.dtypes import register_extension_dtype
 from pandas.core.dtypes.missing import isna
 
-from pandas.core import nanops, ops
+from pandas.core import ops
 from pandas.core.array_algos import masked_reductions
-from pandas.core.indexers import check_array_indexer
 from pandas.core.ops import invalid_comparison
 from pandas.core.ops.common import unpack_zerodim_and_defer
 from pandas.core.tools.numeric import to_numeric
@@ -417,19 +415,8 @@ def reconstruct(x):
         else:
             return reconstruct(result)
 
-    def __setitem__(self, key, value) -> None:
-        _is_scalar = is_scalar(value)
-        if _is_scalar:
-            value = [value]
-        value, mask = coerce_to_array(value, dtype=self.dtype)
-
-        if _is_scalar:
-            value = value[0]
-            mask = mask[0]
-
-        key = check_array_indexer(self, key)
-        self._data[key] = value
-        self._mask[key] = mask
+    def _coerce_to_array(self, value) -> Tuple[np.ndarray, np.ndarray]:
+        return coerce_to_array(value, dtype=self.dtype)
 
     def astype(self, dtype, copy: bool = True) -> ArrayLike:
         """
@@ -553,27 +540,6 @@ def cmp_method(self, other):
         name = f"__{op.__name__}__"
         return set_function_name(cmp_method, name, cls)
 
-    def _reduce(self, name: str, skipna: bool = True, **kwargs):
-        data = self._data
-        mask = self._mask
-
-        if name in {"sum", "prod", "min", "max"}:
-            op = getattr(masked_reductions, name)
-            return op(data, mask, skipna=skipna, **kwargs)
-
-        # coerce to a nan-aware float if needed
-        # (we explicitly use NaN within reductions)
-        if self._hasna:
-            data = self.to_numpy("float64", na_value=np.nan)
-
-        op = getattr(nanops, "nan" + name)
-        result = op(data, axis=0, skipna=skipna, mask=mask, **kwargs)
-
-        if np.isnan(result):
-            return libmissing.NA
-
-        return result
-
     def sum(self, skipna=True, min_count=0, **kwargs):
         nv.validate_sum((), kwargs)
         result = masked_reductions.sum(
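On the integer side the same hook applies: the inherited __setitem__ now calls IntegerArray._coerce_to_array, which routes through coerce_to_array(value, dtype=self.dtype) exactly as before. An illustrative round trip (assuming a pandas build with nullable integer dtypes):

import pandas as pd

arr = pd.array([1, 2, None], dtype="Int64")
arr[0] = None       # scalar NA assignment sets the mask entry
arr[1:] = [10, 11]  # listlike values are validated and coerced to Int64
print(arr)          # <IntegerArray> [<NA>, 10, 11]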
47 changes: 46 additions & 1 deletion pandas/core/arrays/masked.py
@@ -8,10 +8,17 @@
 from pandas.util._decorators import doc
 
 from pandas.core.dtypes.base import ExtensionDtype
-from pandas.core.dtypes.common import is_integer, is_object_dtype, is_string_dtype
+from pandas.core.dtypes.common import (
+    is_integer,
+    is_object_dtype,
+    is_scalar,
+    is_string_dtype,
+)
 from pandas.core.dtypes.missing import isna, notna
 
+from pandas.core import nanops
 from pandas.core.algorithms import _factorize_array, take
+from pandas.core.array_algos import masked_reductions
 from pandas.core.arrays import ExtensionArray, ExtensionOpsMixin
 from pandas.core.indexers import check_array_indexer
 
@@ -77,6 +84,23 @@ def __getitem__(self, item):
 
         return type(self)(self._data[item], self._mask[item])
 
+    def _coerce_to_array(self, values) -> Tuple[np.ndarray, np.ndarray]:
+        raise AbstractMethodError(self)
+
+    def __setitem__(self, key, value) -> None:
+        _is_scalar = is_scalar(value)
+        if _is_scalar:
+            value = [value]
+        value, mask = self._coerce_to_array(value)
+
+        if _is_scalar:
+            value = value[0]
+            mask = mask[0]
+
+        key = check_array_indexer(self, key)
+        self._data[key] = value
+        self._mask[key] = mask
+
     def __iter__(self):
         for i in range(len(self)):
             if self._mask[i]:
@@ -305,3 +329,24 @@ def value_counts(self, dropna: bool = True) -> "Series":
         counts = IntegerArray(counts, mask)
 
         return Series(counts, index=index)
+
+    def _reduce(self, name: str, skipna: bool = True, **kwargs):
+        data = self._data
+        mask = self._mask
+
+        if name in {"sum", "prod", "min", "max"}:
+            op = getattr(masked_reductions, name)
+            return op(data, mask, skipna=skipna, **kwargs)
+
+        # coerce to a nan-aware float if needed
+        # (we explicitly use NaN within reductions)
+        if self._hasna:
+            data = self.to_numpy("float64", na_value=np.nan)
+
+        op = getattr(nanops, "nan" + name)
+        result = op(data, axis=0, skipna=skipna, mask=mask, **kwargs)
+
+        if np.isnan(result):
+            return libmissing.NA
+
+        return result
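
The fallback branch in _reduce above is the part that still relies on nanops: reductions without a masked implementation (mean, std, var, median, ...) convert the data to a NaN-filled float64 array, run the existing nan-aware reduction, and translate a NaN outcome back to pd.NA. A rough equivalent of that path for a mean over Int64 data [1, 2, <NA>] (illustrative, not pandas source):

import numpy as np

data = np.array([1.0, 2.0, np.nan])  # result of to_numpy("float64", na_value=np.nan)
print(np.nanmean(data))  # 1.5 -- the skipna=True behaviour
print(data.mean())       # nan -- the skipna=False behaviour; pandas reports this as pd.NA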
