Skip to content

Commit

Permalink
REF: combine Block _can_hold_element methods (pandas-dev#40709)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored Apr 1, 2021
1 parent 65860fa commit e1dd032
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 72 deletions.
54 changes: 50 additions & 4 deletions pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@
from pandas.core.arrays import (
DatetimeArray,
ExtensionArray,
IntervalArray,
PeriodArray,
TimedeltaArray,
)

_int8_max = np.iinfo(np.int8).max
Expand Down Expand Up @@ -2169,32 +2172,68 @@ def validate_numeric_casting(dtype: np.dtype, value: Scalar) -> None:
raise ValueError(f"Cannot assign {type(value).__name__} to bool series")


def can_hold_element(dtype: np.dtype, element: Any) -> bool:
def can_hold_element(arr: ArrayLike, element: Any) -> bool:
"""
Can we do an inplace setitem with this element in an array with this dtype?
Parameters
----------
dtype : np.dtype
arr : np.ndarray or ExtensionArray
element : Any
Returns
-------
bool
"""
dtype = arr.dtype
if not isinstance(dtype, np.dtype) or dtype.kind in ["m", "M"]:
if isinstance(dtype, (PeriodDtype, IntervalDtype, DatetimeTZDtype, np.dtype)):
# np.dtype here catches datetime64ns and timedelta64ns; we assume
# in this case that we have DatetimeArray/TimedeltaArray
arr = cast(
"PeriodArray | DatetimeArray | TimedeltaArray | IntervalArray", arr
)
try:
arr._validate_setitem_value(element)
return True
except (ValueError, TypeError):
return False

# This is technically incorrect, but maintains the behavior of
# ExtensionBlock._can_hold_element
return True

tipo = maybe_infer_dtype_type(element)

if dtype.kind in ["i", "u"]:
if tipo is not None:
return tipo.kind in ["i", "u"] and dtype.itemsize >= tipo.itemsize
if tipo.kind not in ["i", "u"]:
# Anything other than integer we cannot hold
return False
elif dtype.itemsize < tipo.itemsize:
return False
elif not isinstance(tipo, np.dtype):
# i.e. nullable IntegerDtype; we can put this into an ndarray
# losslessly iff it has no NAs
return not element._mask.any()
return True

# We have not inferred an integer from the dtype
# check if we have a builtin int or a float equal to an int
return is_integer(element) or (is_float(element) and element.is_integer())

elif dtype.kind == "f":
if tipo is not None:
return tipo.kind in ["f", "i", "u"]
# TODO: itemsize check?
if tipo.kind not in ["f", "i", "u"]:
# Anything other than float/integer we cannot hold
return False
elif not isinstance(tipo, np.dtype):
# i.e. nullable IntegerDtype or FloatingDtype;
# we can put this into an ndarray losslessly iff it has no NAs
return not element._mask.any()
return True

return lib.is_integer(element) or lib.is_float(element)

elif dtype.kind == "c":
Expand All @@ -2212,4 +2251,11 @@ def can_hold_element(dtype: np.dtype, element: Any) -> bool:
elif dtype == object:
return True

elif dtype.kind == "S":
# TODO: test tests.frame.methods.test_replace tests get here,
# need more targeted tests. xref phofl has a PR about this
if tipo is not None:
return tipo.kind == "S" and tipo.itemsize <= dtype.itemsize
return isinstance(element, bytes) and len(element) <= dtype.itemsize

raise NotImplementedError(dtype)
2 changes: 1 addition & 1 deletion pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4482,7 +4482,7 @@ def _validate_fill_value(self, value):
TypeError
If the value cannot be inserted into an array of this dtype.
"""
if not can_hold_element(self.dtype, value):
if not can_hold_element(self._values, value):
raise TypeError
return value

Expand Down
67 changes: 10 additions & 57 deletions pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import (
TYPE_CHECKING,
Any,
Callable,
List,
Optional,
Tuple,
Expand All @@ -18,8 +17,6 @@
import numpy as np

from pandas._libs import (
Interval,
Period,
Timestamp,
algos as libalgos,
internals as libinternals,
Expand Down Expand Up @@ -102,6 +99,7 @@
PeriodArray,
TimedeltaArray,
)
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
from pandas.core.base import PandasObject
import pandas.core.common as com
import pandas.core.computation.expressions as expressions
Expand All @@ -122,7 +120,6 @@
Float64Index,
Index,
)
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray

# comparison is faster than is_object_dtype
_dtype_obj = np.dtype("object")
Expand Down Expand Up @@ -625,9 +622,11 @@ def convert(
"""
return [self.copy()] if copy else [self]

@final
def _can_hold_element(self, element: Any) -> bool:
""" require the same dtype as ourselves """
raise NotImplementedError("Implemented on subclasses")
element = extract_array(element, extract_numpy=True)
return can_hold_element(self.values, element)

@final
def should_store(self, value: ArrayLike) -> bool:
Expand Down Expand Up @@ -1545,7 +1544,7 @@ def setitem(self, indexer, value):
be a compatible shape.
"""
if not self._can_hold_element(value):
# This is only relevant for DatetimeTZBlock, ObjectValuesExtensionBlock,
# This is only relevant for DatetimeTZBlock, PeriodDtype, IntervalDtype,
# which has a non-trivial `_can_hold_element`.
# https://github.com/pandas-dev/pandas/issues/24020
# Need a dedicated setitem until GH#24020 (type promotion in setitem
Expand Down Expand Up @@ -1597,10 +1596,6 @@ def take_nd(

return self.make_block_same_class(new_values, new_mgr_locs)

def _can_hold_element(self, element: Any) -> bool:
# TODO: We may need to think about pushing this onto the array.
return True

def _slice(self, slicer):
"""
Return a slice of my values.
Expand Down Expand Up @@ -1746,54 +1741,22 @@ def _unstack(self, unstacker, fill_value, new_placement):
return blocks, mask


class HybridMixin:
"""
Mixin for Blocks backed (maybe indirectly) by ExtensionArrays.
"""

array_values: Callable

def _can_hold_element(self, element: Any) -> bool:
values = self.array_values

try:
# error: "Callable[..., Any]" has no attribute "_validate_setitem_value"
values._validate_setitem_value(element) # type: ignore[attr-defined]
return True
except (ValueError, TypeError):
return False


class ObjectValuesExtensionBlock(HybridMixin, ExtensionBlock):
"""
Block providing backwards-compatibility for `.values`.
Used by PeriodArray and IntervalArray to ensure that
Series[T].values is an ndarray of objects.
"""

pass


class NumericBlock(Block):
__slots__ = ()
is_numeric = True

def _can_hold_element(self, element: Any) -> bool:
element = extract_array(element, extract_numpy=True)
if isinstance(element, (IntegerArray, FloatingArray)):
if element._mask.any():
return False
return can_hold_element(self.dtype, element)


class NDArrayBackedExtensionBlock(HybridMixin, Block):
class NDArrayBackedExtensionBlock(Block):
"""
Block backed by an NDArrayBackedExtensionArray
"""

values: NDArrayBackedExtensionArray

@property
def array_values(self) -> NDArrayBackedExtensionArray:
return self.values

@property
def is_view(self) -> bool:
""" return a boolean if I am possibly a view """
Expand Down Expand Up @@ -1901,10 +1864,6 @@ class DatetimeLikeBlockMixin(NDArrayBackedExtensionBlock):

is_numeric = False

@cache_readonly
def array_values(self):
return self.values


class DatetimeBlock(DatetimeLikeBlockMixin):
__slots__ = ()
Expand All @@ -1920,7 +1879,6 @@ class DatetimeTZBlock(ExtensionBlock, DatetimeLikeBlockMixin):
is_numeric = False

internal_values = Block.internal_values
_can_hold_element = DatetimeBlock._can_hold_element
diff = DatetimeBlock.diff
where = DatetimeBlock.where
putmask = DatetimeLikeBlockMixin.putmask
Expand Down Expand Up @@ -1983,9 +1941,6 @@ def convert(
res_values = ensure_block_shape(res_values, self.ndim)
return [self.make_block(res_values)]

def _can_hold_element(self, element: Any) -> bool:
return True


class CategoricalBlock(ExtensionBlock):
# this Block type is kept for backwards-compatibility
Expand Down Expand Up @@ -2052,8 +2007,6 @@ def get_block_type(values, dtype: Optional[Dtype] = None):
cls = CategoricalBlock
elif vtype is Timestamp:
cls = DatetimeTZBlock
elif vtype is Interval or vtype is Period:
cls = ObjectValuesExtensionBlock
elif isinstance(dtype, ExtensionDtype):
# Note: need to be sure PandasArray is unwrapped before we get here
cls = ExtensionBlock
Expand Down
9 changes: 0 additions & 9 deletions pandas/core/internals/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@
CategoricalBlock,
DatetimeTZBlock,
ExtensionBlock,
ObjectValuesExtensionBlock,
ensure_block_shape,
extend_blocks,
get_block_type,
Expand Down Expand Up @@ -1841,14 +1840,6 @@ def _form_blocks(

blocks.extend(external_blocks)

if len(items_dict["ObjectValuesExtensionBlock"]):
external_blocks = [
new_block(array, klass=ObjectValuesExtensionBlock, placement=i, ndim=2)
for i, array in items_dict["ObjectValuesExtensionBlock"]
]

blocks.extend(external_blocks)

if len(extra_locs):
shape = (len(extra_locs),) + tuple(len(x) for x in axes[1:])

Expand Down
13 changes: 12 additions & 1 deletion pandas/tests/extension/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import pandas.util._test_decorators as td

from pandas.core.dtypes.cast import can_hold_element
from pandas.core.dtypes.dtypes import (
ExtensionDtype,
PandasDtype,
Expand All @@ -27,7 +28,10 @@
import pandas as pd
import pandas._testing as tm
from pandas.core.arrays.numpy_ import PandasArray
from pandas.core.internals import managers
from pandas.core.internals import (
blocks,
managers,
)
from pandas.tests.extension import base

# TODO(ArrayManager) PandasArray
Expand All @@ -45,6 +49,12 @@ def _extract_array_patched(obj):
return obj


def _can_hold_element_patched(obj, element) -> bool:
if isinstance(element, PandasArray):
element = element.to_numpy()
return can_hold_element(obj, element)


@pytest.fixture(params=["float", "object"])
def dtype(request):
return PandasDtype(np.dtype(request.param))
Expand All @@ -70,6 +80,7 @@ def allow_in_pandas(monkeypatch):
with monkeypatch.context() as m:
m.setattr(PandasArray, "_typ", "extension")
m.setattr(managers, "_extract_array", _extract_array_patched)
m.setattr(blocks, "can_hold_element", _can_hold_element_patched)
yield


Expand Down

0 comments on commit e1dd032

Please sign in to comment.