diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index c2cf6afc1a7b53..f3b11e52cdd7ad 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -483,6 +483,39 @@ def infer_dtype_from_array(arr, pandas_dtype=False): return arr.dtype, arr +def maybe_infer_dtype_type(element): + """Try to infer an object's dtype, for use in arithmetic ops + + Uses `element.dtype` if that's available. + Objects implementing the iterator protocol are cast to a NumPy array, + and from there the array's type is used. + + Parameters + ---------- + element : object + Possibly has a `.dtype` attribute, and possibly the iterator + protocol. + + Returns + ------- + tipo : type + + Examples + -------- + >>> from collections import namedtuple + >>> Foo = namedtuple("Foo", "dtype") + >>> maybe_infer_dtype_type(Foo(np.dtype("i8"))) + numpy.int64 + """ + tipo = None + if hasattr(element, 'dtype'): + tipo = element.dtype + elif is_list_like(element): + element = np.asarray(element) + tipo = element.dtype + return tipo + + def maybe_upcast(values, fill_value=np.nan, dtype=None, copy=False): """ provide explict type promotion and coercion diff --git a/pandas/core/internals.py b/pandas/core/internals.py index 1fddf985f0cdbb..90de4ded18f8c7 100644 --- a/pandas/core/internals.py +++ b/pandas/core/internals.py @@ -44,7 +44,8 @@ soft_convert_objects, maybe_convert_objects, astype_nansafe, - find_common_type) + find_common_type, + maybe_infer_dtype_type) from pandas.core.dtypes.missing import ( isna, notna, array_equivalent, _isna_compat, @@ -629,10 +630,9 @@ def convert(self, copy=True, **kwargs): def _can_hold_element(self, element): """ require the same dtype as ourselves """ dtype = self.values.dtype.type - if is_list_like(element): - element = np.asarray(element) - tipo = element.dtype.type - return issubclass(tipo, dtype) + tipo = maybe_infer_dtype_type(element) + if tipo is not None: + return issubclass(tipo.type, dtype) return isinstance(element, dtype) def _try_cast_result(self, result, dtype=None): @@ -1806,11 +1806,10 @@ class FloatBlock(FloatOrComplexBlock): _downcast_dtype = 'int64' def _can_hold_element(self, element): - if is_list_like(element): - element = np.asarray(element) - tipo = element.dtype.type - return (issubclass(tipo, (np.floating, np.integer)) and - not issubclass(tipo, (np.datetime64, np.timedelta64))) + tipo = maybe_infer_dtype_type(element) + if tipo is not None: + return (issubclass(tipo.type, (np.floating, np.integer)) and + not issubclass(tipo.type, (np.datetime64, np.timedelta64))) return (isinstance(element, (float, int, np.floating, np.int_)) and not isinstance(element, (bool, np.bool_, datetime, timedelta, np.datetime64, np.timedelta64))) @@ -1856,9 +1855,9 @@ class ComplexBlock(FloatOrComplexBlock): is_complex = True def _can_hold_element(self, element): - if is_list_like(element): - element = np.array(element) - return issubclass(element.dtype.type, + tipo = maybe_infer_dtype_type(element) + if tipo is not None: + return issubclass(tipo.type, (np.floating, np.integer, np.complexfloating)) return (isinstance(element, (float, int, complex, np.float_, np.int_)) and @@ -1874,12 +1873,12 @@ class IntBlock(NumericBlock): _can_hold_na = False def _can_hold_element(self, element): - if is_list_like(element): - element = np.array(element) - tipo = element.dtype.type - return (issubclass(tipo, np.integer) and - not issubclass(tipo, (np.datetime64, np.timedelta64)) and - self.dtype.itemsize >= element.dtype.itemsize) + tipo = maybe_infer_dtype_type(element) + if tipo is not None: + return (issubclass(tipo.type, np.integer) and + not issubclass(tipo.type, (np.datetime64, + np.timedelta64)) and + self.dtype.itemsize >= tipo.itemsize) return is_integer(element) def should_store(self, value): @@ -1917,10 +1916,9 @@ def _box_func(self): return lambda x: tslib.Timedelta(x, unit='ns') def _can_hold_element(self, element): - if is_list_like(element): - element = np.array(element) - tipo = element.dtype.type - return issubclass(tipo, np.timedelta64) + tipo = maybe_infer_dtype_type(element) + if tipo is not None: + return issubclass(tipo.type, np.timedelta64) return isinstance(element, (timedelta, np.timedelta64)) def fillna(self, value, **kwargs): @@ -2018,9 +2016,9 @@ class BoolBlock(NumericBlock): _can_hold_na = False def _can_hold_element(self, element): - if is_list_like(element): - element = np.asarray(element) - return issubclass(element.dtype.type, np.bool_) + tipo = maybe_infer_dtype_type(element) + if tipo is not None: + return issubclass(tipo.type, np.bool_) return isinstance(element, (bool, np.bool_)) def should_store(self, value): @@ -2450,7 +2448,9 @@ def _astype(self, dtype, mgr=None, **kwargs): return super(DatetimeBlock, self)._astype(dtype=dtype, **kwargs) def _can_hold_element(self, element): - if is_list_like(element): + tipo = maybe_infer_dtype_type(element) + if tipo is not None: + # TODO: this still uses asarray, instead of dtype.type element = np.array(element) return element.dtype == _NS_DTYPE or element.dtype == np.int64 return (is_integer(element) or isinstance(element, datetime) or diff --git a/pandas/tests/internals/test_internals.py b/pandas/tests/internals/test_internals.py index f40fc151676da1..c182db35c0c893 100644 --- a/pandas/tests/internals/test_internals.py +++ b/pandas/tests/internals/test_internals.py @@ -2,6 +2,7 @@ # pylint: disable=W0102 from datetime import datetime, date +import operator import sys import pytest import numpy as np @@ -1213,3 +1214,64 @@ def assert_add_equals(val, inc, result): with pytest.raises(ValueError): BlockPlacement(slice(2, None, -1)).add(-1) + + +class DummyElement(object): + def __init__(self, value, dtype): + self.value = value + self.dtype = np.dtype(dtype) + + def __array__(self): + return np.array(self.value, dtype=self.dtype) + + def __str__(self): + return "DummyElement({}, {})".format(self.value, self.dtype) + + def __repr__(self): + return str(self) + + def astype(self, dtype, copy=False): + self.dtype = dtype + return self + + def view(self, dtype): + return type(self)(self.value.view(dtype), dtype) + + def any(self, axis=None): + return bool(self.value) + + +class TestCanHoldElement(object): + @pytest.mark.parametrize('value, dtype', [ + (1, 'i8'), + (1.0, 'f8'), + (1j, 'complex128'), + (True, 'bool'), + (np.timedelta64(20, 'ns'), '