Skip to content

Commit

Permalink
BUG: replace coerces incorrect dtype
Browse files Browse the repository at this point in the history
closes pandas-dev#12747

Author: sinhrks <[email protected]>

This patch had conflicts when merged, resolved by
Committer: Jeff Reback <[email protected]>

Closes pandas-dev#12780 from sinhrks/replace_type and squashes the following commits:

f9154e8 [sinhrks] remove unnecessary comments
279fdf6 [sinhrks] remove import failure
de44877 [sinhrks] BUG: replace coerces incorrect dtype
  • Loading branch information
sinhrks authored and jreback committed Mar 20, 2017
1 parent b1e29db commit 8bde21a
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 24 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v0.20.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,7 @@ Bug Fixes


- Bug in the display of ``.info()`` where a qualifier (+) would always be displayed with a ``MultiIndex`` that contains only non-strings (:issue:`15245`)
- Bug in ``.replace()`` where the result could have an incorrect dtype (:issue:`12747`)

- Bug in ``.asfreq()``, where frequency was not set for empty ``Series`` (:issue:`14320`)

Expand Down
20 changes: 17 additions & 3 deletions pandas/core/internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -1894,8 +1894,11 @@ def convert(self, *args, **kwargs):
blocks.append(newb)

else:
values = fn(
self.values.ravel(), **fn_kwargs).reshape(self.values.shape)
values = fn(self.values.ravel(), **fn_kwargs)
try:
values = values.reshape(self.values.shape)
except NotImplementedError:
pass
blocks.append(make_block(values, ndim=self.ndim,
placement=self.mgr_locs))

Expand Down Expand Up @@ -3238,6 +3241,16 @@ def comp(s):
return _possibly_compare(values, getattr(s, 'asm8', s),
operator.eq)

def _cast_scalar(block, scalar):
    """
    Make ``block`` able to hold ``scalar`` before it is written via putmask.

    The dtype of ``scalar`` is inferred with pandas extension types
    enabled. When that dtype equals the block's dtype, the block is
    returned untouched together with the value produced by the
    inference helper. Otherwise the block is cast to the common dtype
    of the two and the original, unconverted scalar is returned.

    Returns
    -------
    tuple of (block, value)
    """
    inferred_dtype, inferred_val = _infer_dtype_from_scalar(
        scalar, pandas_dtype=True)

    if is_dtype_equal(block.dtype, inferred_dtype):
        return block, inferred_val

    # dtypes differ: upcast the block and keep the caller's original value
    common_dtype = _find_common_type([block.dtype, inferred_dtype])
    return block.astype(common_dtype), scalar

masks = [comp(s) for i, s in enumerate(src_list)]

result_blocks = []
Expand All @@ -3260,7 +3273,8 @@ def comp(s):
# particular block
m = masks[i][b.mgr_locs.indexer]
if m.any():
new_rb.extend(b.putmask(m, d, inplace=True))
b, val = _cast_scalar(b, d)
new_rb.extend(b.putmask(m, val, inplace=True))
else:
new_rb.append(b)
rb = new_rb
Expand Down
50 changes: 39 additions & 11 deletions pandas/tests/indexing/test_coercion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1153,12 +1153,27 @@ def setUp(self):
self.rep['float64'] = [1.1, 2.2]
self.rep['complex128'] = [1 + 1j, 2 + 2j]
self.rep['bool'] = [True, False]
self.rep['datetime64[ns]'] = [pd.Timestamp('2011-01-01'),
pd.Timestamp('2011-01-03')]

for tz in ['UTC', 'US/Eastern']:
# to test tz => different tz replacement
key = 'datetime64[ns, {0}]'.format(tz)
self.rep[key] = [pd.Timestamp('2011-01-01', tz=tz),
pd.Timestamp('2011-01-03', tz=tz)]

self.rep['timedelta64[ns]'] = [pd.Timedelta('1 day'),
pd.Timedelta('2 day')]

def _assert_replace_conversion(self, from_key, to_key, how):
index = pd.Index([3, 4], name='xxx')
obj = pd.Series(self.rep[from_key], index=index, name='yyy')
self.assertEqual(obj.dtype, from_key)

if (from_key.startswith('datetime') and to_key.startswith('datetime')):
# different tz, currently mask_missing raises SystemError
return

if how == 'dict':
replacer = dict(zip(self.rep[from_key], self.rep[to_key]))
elif how == 'series':
Expand All @@ -1175,17 +1190,12 @@ def _assert_replace_conversion(self, from_key, to_key, how):
pytest.skip("windows platform buggy: {0} -> {1}".format
(from_key, to_key))

if ((from_key == 'float64' and
to_key in ('bool', 'int64')) or

if ((from_key == 'float64' and to_key in ('bool', 'int64')) or
(from_key == 'complex128' and
to_key in ('bool', 'int64', 'float64')) or

(from_key == 'int64' and
to_key in ('bool')) or

# TODO_GH12747 The result must be int?
(from_key == 'bool' and to_key == 'int64')):
# GH12747 The result must be int?
(from_key == 'int64' and to_key in ('bool'))):

# buggy on 32-bit
if tm.is_platform_32bit():
Expand Down Expand Up @@ -1248,13 +1258,31 @@ def test_replace_series_bool(self):
self._assert_replace_conversion(from_key, to_key, how='series')

def test_replace_series_datetime64(self):
pass
from_key = 'datetime64[ns]'
for to_key in self.rep:
self._assert_replace_conversion(from_key, to_key, how='dict')

from_key = 'datetime64[ns]'
for to_key in self.rep:
self._assert_replace_conversion(from_key, to_key, how='series')

def test_replace_series_datetime64tz(self):
pass
from_key = 'datetime64[ns, US/Eastern]'
for to_key in self.rep:
self._assert_replace_conversion(from_key, to_key, how='dict')

from_key = 'datetime64[ns, US/Eastern]'
for to_key in self.rep:
self._assert_replace_conversion(from_key, to_key, how='series')

def test_replace_series_timedelta64(self):
pass
from_key = 'timedelta64[ns]'
for to_key in self.rep:
self._assert_replace_conversion(from_key, to_key, how='dict')

from_key = 'timedelta64[ns]'
for to_key in self.rep:
self._assert_replace_conversion(from_key, to_key, how='series')

def test_replace_series_period(self):
pass
4 changes: 2 additions & 2 deletions pandas/tests/series/test_replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ def check_replace(to_rep, val, expected):
tm.assert_series_equal(expected, r)
tm.assert_series_equal(expected, sc)

# should NOT upcast to float
e = pd.Series([0, 1, 2, 3, 4])
# MUST upcast to float
e = pd.Series([0., 1., 2., 3., 4.])
tr, v = [3], [3.0]
check_replace(tr, v, e)

Expand Down
37 changes: 29 additions & 8 deletions pandas/types/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
_ensure_int32, _ensure_int64,
_NS_DTYPE, _TD_DTYPE, _INT64_DTYPE,
_POSSIBLY_CAST_DTYPES)
from .dtypes import ExtensionDtype
from .dtypes import ExtensionDtype, DatetimeTZDtype, PeriodDtype
from .generic import ABCDatetimeIndex, ABCPeriodIndex, ABCSeries
from .missing import isnull, notnull
from .inference import is_list_like
Expand Down Expand Up @@ -312,8 +312,17 @@ def _maybe_promote(dtype, fill_value=np.nan):
return dtype, fill_value


def _infer_dtype_from_scalar(val):
""" interpret the dtype from a scalar """
def _infer_dtype_from_scalar(val, pandas_dtype=False):
"""
interpret the dtype from a scalar
Parameters
----------
pandas_dtype : bool, default False
whether to infer dtype including pandas extension types.
If False, a scalar that belongs to a pandas extension type is
inferred as object
"""

dtype = np.object_

Expand All @@ -336,13 +345,20 @@ def _infer_dtype_from_scalar(val):

dtype = np.object_

elif isinstance(val, (np.datetime64,
datetime)) and getattr(val, 'tzinfo', None) is None:
val = lib.Timestamp(val).value
dtype = np.dtype('M8[ns]')
elif isinstance(val, (np.datetime64, datetime)):
val = tslib.Timestamp(val)
if val is tslib.NaT or val.tz is None:
dtype = np.dtype('M8[ns]')
else:
if pandas_dtype:
dtype = DatetimeTZDtype(unit='ns', tz=val.tz)
else:
# return datetimetz as object
return np.object_, val
val = val.value

elif isinstance(val, (np.timedelta64, timedelta)):
val = lib.Timedelta(val).value
val = tslib.Timedelta(val).value
dtype = np.dtype('m8[ns]')

elif is_bool(val):
Expand All @@ -363,6 +379,11 @@ def _infer_dtype_from_scalar(val):
elif is_complex(val):
dtype = np.complex_

elif pandas_dtype:
if lib.is_period(val):
dtype = PeriodDtype(freq=val.freq)
val = val.ordinal

return dtype, val


Expand Down

0 comments on commit 8bde21a

Please sign in to comment.