Skip to content

Commit

Permalink
REF: implement putmask for CI/DTI/TDI/PI (#36400)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored Sep 17, 2020
1 parent d4947a9 commit 52c81a9
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 7 deletions.
5 changes: 5 additions & 0 deletions pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -1171,6 +1171,11 @@ def map(self, mapper):
# -------------------------------------------------------------
# Validators; ideally these can be de-duplicated

def _validate_where_value(self, value):
if is_scalar(value):
return self._validate_fill_value(value)
return self._validate_listlike(value)

def _validate_insert_value(self, value) -> int:
code = self.categories.get_indexer([value])
if (code == -1) and not (is_scalar(value) and isna(value)):
Expand Down
3 changes: 0 additions & 3 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4232,9 +4232,6 @@ def putmask(self, mask, value):
try:
converted = self._validate_fill_value(value)
np.putmask(values, mask, converted)
if is_period_dtype(self.dtype):
# .values cast to object, so we need to cast back
values = type(self)(values)._data
return self._shallow_copy(values)
except (ValueError, TypeError) as err:
if is_object_dtype(self):
Expand Down
11 changes: 11 additions & 0 deletions pandas/core/indexes/category.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,17 @@ def where(self, cond, other=None):
cat = Categorical(values, dtype=self.dtype)
return type(self)._simple_new(cat, name=self.name)

def putmask(self, mask, value):
try:
code_value = self._data._validate_where_value(value)
except (TypeError, ValueError):
return self.astype(object).putmask(mask, value)

codes = self._data._ndarray.copy()
np.putmask(codes, mask, code_value)
cat = self._data._from_backing_data(codes)
return type(self)._simple_new(cat, name=self.name)

def reindex(self, target, method=None, level=None, limit=None, tolerance=None):
"""
Create index with target's values (move/add/delete values as necessary)
Expand Down
13 changes: 12 additions & 1 deletion pandas/core/indexes/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,18 @@ def where(self, cond, other=None):
raise TypeError(f"Where requires matching dtype, not {oth}") from err

result = np.where(cond, values, other).astype("i8")
arr = type(self._data)._simple_new(result, dtype=self.dtype)
arr = self._data._from_backing_data(result)
return type(self)._simple_new(arr, name=self.name)

def putmask(self, mask, value):
try:
value = self._data._validate_where_value(value)
except (TypeError, ValueError):
return self.astype(object).putmask(mask, value)

result = self._data._ndarray.copy()
np.putmask(result, mask, value)
arr = self._data._from_backing_data(result)
return type(self)._simple_new(arr, name=self.name)

def _summary(self, name=None) -> str:
Expand Down
7 changes: 4 additions & 3 deletions pandas/tests/indexes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,16 +846,17 @@ def test_map_str(self):
def test_putmask_with_wrong_mask(self):
# GH18368
index = self.create_index()
fill = index[0]

msg = "putmask: mask and data must be the same size"
with pytest.raises(ValueError, match=msg):
index.putmask(np.ones(len(index) + 1, np.bool_), 1)
index.putmask(np.ones(len(index) + 1, np.bool_), fill)

with pytest.raises(ValueError, match=msg):
index.putmask(np.ones(len(index) - 1, np.bool_), 1)
index.putmask(np.ones(len(index) - 1, np.bool_), fill)

with pytest.raises(ValueError, match=msg):
index.putmask("foo", 1)
index.putmask("foo", fill)

@pytest.mark.parametrize("copy", [True, False])
@pytest.mark.parametrize("name", [None, "foo"])
Expand Down

0 comments on commit 52c81a9

Please sign in to comment.