Skip to content

Commit

Permalink
ENH: EA._from_scalars (pandas-dev#53089)
Browse files Browse the repository at this point in the history
* ENH: BaseStringArray._from_scalars

* WIP: EA._from_scalars

* ENH: implement EA._from_scalars

* Fix StringDtype/CategoricalDtype combine

* mypy fixup
  • Loading branch information
jbrockmendel authored Oct 16, 2023
1 parent 32c9c8f commit 746e5ee
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 17 deletions.
33 changes: 33 additions & 0 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
AstypeArg,
AxisInt,
Dtype,
DtypeObj,
FillnaOptions,
InterpolateOptions,
NumpySorter,
Expand Down Expand Up @@ -293,6 +294,38 @@ def _from_sequence(cls, scalars, *, dtype: Dtype | None = None, copy: bool = Fal
"""
raise AbstractMethodError(cls)

@classmethod
def _from_scalars(cls, scalars, *, dtype: DtypeObj) -> Self:
"""
Strict analogue to _from_sequence, allowing only sequences of scalars
that should be specifically inferred to the given dtype.
Parameters
----------
scalars : sequence
dtype : ExtensionDtype
Raises
------
TypeError or ValueError
Notes
-----
This is called in a try/except block when casting the result of a
pointwise operation.
"""
try:
return cls._from_sequence(scalars, dtype=dtype, copy=False)
except (ValueError, TypeError):
raise
except Exception:
warnings.warn(
"_from_scalars should only raise ValueError or TypeError. "
"Consider overriding _from_scalars where appropriate.",
stacklevel=find_stack_level(),
)
raise

@classmethod
def _from_sequence_of_strings(
cls, strings, *, dtype: Dtype | None = None, copy: bool = False
Expand Down
17 changes: 17 additions & 0 deletions pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
AstypeArg,
AxisInt,
Dtype,
DtypeObj,
NpDtype,
Ordered,
Self,
Expand Down Expand Up @@ -509,6 +510,22 @@ def _from_sequence(
) -> Self:
return cls(scalars, dtype=dtype, copy=copy)

@classmethod
def _from_scalars(cls, scalars, *, dtype: DtypeObj) -> Self:
if dtype is None:
# The _from_scalars strictness doesn't make much sense in this case.
raise NotImplementedError

res = cls._from_sequence(scalars, dtype=dtype)

# if there are any non-category elements in scalars, these will be
# converted to NAs in res.
mask = isna(scalars)
if not (mask == res.isna()).all():
# Some non-category element in scalars got converted to NA in res.
raise ValueError
return res

@overload
def astype(self, dtype: npt.DTypeLike, copy: bool = ...) -> np.ndarray:
...
Expand Down
9 changes: 9 additions & 0 deletions pandas/core/arrays/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@

from pandas._typing import (
DateTimeErrorChoices,
DtypeObj,
IntervalClosedType,
Self,
TimeAmbiguous,
Expand Down Expand Up @@ -266,6 +267,14 @@ def _scalar_type(self) -> type[Timestamp]:
_freq: BaseOffset | None = None
_default_dtype = DT64NS_DTYPE # used in TimeLikeOps.__init__

@classmethod
def _from_scalars(cls, scalars, *, dtype: DtypeObj) -> Self:
if lib.infer_dtype(scalars, skipna=True) not in ["datetime", "datetime64"]:
# TODO: require any NAs be valid-for-DTA
# TODO: if dtype is passed, check for tzawareness compat?
raise ValueError
return cls._from_sequence(scalars, dtype=dtype)

@classmethod
def _validate_dtype(cls, values, dtype):
# used in TimeLikeOps.__init__
Expand Down
8 changes: 8 additions & 0 deletions pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from pandas._typing import (
AxisInt,
Dtype,
DtypeObj,
NumpySorter,
NumpyValueArrayLike,
Scalar,
Expand Down Expand Up @@ -253,6 +254,13 @@ def tolist(self):
return [x.tolist() for x in self]
return list(self.to_numpy())

@classmethod
def _from_scalars(cls, scalars, dtype: DtypeObj) -> Self:
if lib.infer_dtype(scalars, skipna=True) != "string":
# TODO: require any NAs be valid-for-string
raise ValueError
return cls._from_sequence(scalars, dtype=dtype)


# error: Definition of "_concat_same_type" in base class "NDArrayBacked" is
# incompatible with definition in base class "ExtensionArray"
Expand Down
26 changes: 12 additions & 14 deletions pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,16 +464,11 @@ def maybe_cast_pointwise_result(
"""

if isinstance(dtype, ExtensionDtype):
if not isinstance(dtype, (CategoricalDtype, DatetimeTZDtype)):
# TODO: avoid this special-casing
# We have to special case categorical so as not to upcast
# things like counts back to categorical

cls = dtype.construct_array_type()
if same_dtype:
result = _maybe_cast_to_extension_array(cls, result, dtype=dtype)
else:
result = _maybe_cast_to_extension_array(cls, result)
cls = dtype.construct_array_type()
if same_dtype:
result = _maybe_cast_to_extension_array(cls, result, dtype=dtype)
else:
result = _maybe_cast_to_extension_array(cls, result)

elif (numeric_only and dtype.kind in "iufcb") or not numeric_only:
result = maybe_downcast_to_dtype(result, dtype)
Expand All @@ -498,11 +493,14 @@ def _maybe_cast_to_extension_array(
-------
ExtensionArray or obj
"""
from pandas.core.arrays.string_ import BaseStringArray
result: ArrayLike

# Everything can be converted to StringArrays, but we may not want to convert
if issubclass(cls, BaseStringArray) and lib.infer_dtype(obj) != "string":
return obj
if dtype is not None:
try:
result = cls._from_scalars(obj, dtype=dtype)
except (TypeError, ValueError):
return obj
return result

try:
result = cls._from_sequence(obj, dtype=dtype)
Expand Down
13 changes: 11 additions & 2 deletions pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,10 @@
pandas_dtype,
validate_all_hashable,
)
from pandas.core.dtypes.dtypes import ExtensionDtype
from pandas.core.dtypes.dtypes import (
CategoricalDtype,
ExtensionDtype,
)
from pandas.core.dtypes.generic import ABCDataFrame
from pandas.core.dtypes.inference import is_hashable
from pandas.core.dtypes.missing import (
Expand All @@ -100,6 +103,7 @@
from pandas.core.arrays.arrow import StructAccessor
from pandas.core.arrays.categorical import CategoricalAccessor
from pandas.core.arrays.sparse import SparseAccessor
from pandas.core.arrays.string_ import StringDtype
from pandas.core.construction import (
extract_array,
sanitize_array,
Expand Down Expand Up @@ -3377,7 +3381,12 @@ def combine(

# try_float=False is to match agg_series
npvalues = lib.maybe_convert_objects(new_values, try_float=False)
res_values = maybe_cast_pointwise_result(npvalues, self.dtype, same_dtype=False)
# same_dtype here is a kludge to avoid casting e.g. [True, False] to
# ["True", "False"]
same_dtype = isinstance(self.dtype, (StringDtype, CategoricalDtype))
res_values = maybe_cast_pointwise_result(
npvalues, self.dtype, same_dtype=same_dtype
)
return self._constructor(res_values, index=new_index, name=new_name, copy=False)

def combine_first(self, other) -> Series:
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/resample/test_timedelta.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def test_resample_categorical_data_with_timedeltaindex():
index=pd.TimedeltaIndex([0, 10], unit="s", freq="10s"),
)
expected = expected.reindex(["Group_obj", "Group"], axis=1)
expected["Group"] = expected["Group_obj"]
expected["Group"] = expected["Group_obj"].astype("category")
tm.assert_frame_equal(result, expected)


Expand Down

0 comments on commit 746e5ee

Please sign in to comment.