Skip to content
This repository has been archived by the owner on Oct 7, 2024. It is now read-only.

Commit

Permalink
Added fill_value for unstack (pydata#3541)
Browse files Browse the repository at this point in the history
* Added fill_value for unstack

* remove sparse option and fix unintended changes

* a bug fix

* using assert_equal

* assert_equals -> assert_equal
  • Loading branch information
fujiisoup authored Nov 16, 2019
1 parent 52d4845 commit 56c16e4
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 6 deletions.
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ Breaking changes
New Features
~~~~~~~~~~~~

- Added the ``fill_value`` option to :py:meth:`~xarray.DataArray.unstack` and
:py:meth:`~xarray.Dataset.unstack` (:issue:`3518`).
By `Keisuke Fujii <https://github.com/fujiisoup>`_.
- Added the ``max_gap`` kwarg to :py:meth:`~xarray.DataArray.interpolate_na` and
:py:meth:`~xarray.Dataset.interpolate_na`. This controls the maximum size of the data
gap that will be filled by interpolation. By `Deepak Cherian <https://github.com/dcherian>`_.
Expand Down
7 changes: 5 additions & 2 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1726,7 +1726,9 @@ def stack(
return self._from_temp_dataset(ds)

def unstack(
self, dim: Union[Hashable, Sequence[Hashable], None] = None
self,
dim: Union[Hashable, Sequence[Hashable], None] = None,
fill_value: Any = dtypes.NA,
) -> "DataArray":
"""
Unstack existing dimensions corresponding to MultiIndexes into
Expand All @@ -1739,6 +1741,7 @@ def unstack(
dim : hashable or sequence of hashable, optional
Dimension(s) over which to unstack. By default unstacks all
MultiIndexes.
fill_value: value to be filled. By default, np.nan
Returns
-------
Expand Down Expand Up @@ -1770,7 +1773,7 @@ def unstack(
--------
DataArray.stack
"""
ds = self._to_temp_dataset().unstack(dim)
ds = self._to_temp_dataset().unstack(dim, fill_value)
return self._from_temp_dataset(ds)

def to_unstacked_dataset(self, dim, level=0):
Expand Down
13 changes: 9 additions & 4 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3333,7 +3333,7 @@ def ensure_stackable(val):

return data_array

def _unstack_once(self, dim: Hashable) -> "Dataset":
def _unstack_once(self, dim: Hashable, fill_value) -> "Dataset":
index = self.get_index(dim)
index = index.remove_unused_levels()
full_idx = pd.MultiIndex.from_product(index.levels, names=index.names)
Expand All @@ -3342,7 +3342,7 @@ def _unstack_once(self, dim: Hashable) -> "Dataset":
if index.equals(full_idx):
obj = self
else:
obj = self.reindex({dim: full_idx}, copy=False)
obj = self.reindex({dim: full_idx}, copy=False, fill_value=fill_value)

new_dim_names = index.names
new_dim_sizes = [lev.size for lev in index.levels]
Expand All @@ -3368,7 +3368,11 @@ def _unstack_once(self, dim: Hashable) -> "Dataset":
variables, coord_names=coord_names, indexes=indexes
)

def unstack(self, dim: Union[Hashable, Iterable[Hashable]] = None) -> "Dataset":
def unstack(
self,
dim: Union[Hashable, Iterable[Hashable]] = None,
fill_value: Any = dtypes.NA,
) -> "Dataset":
"""
Unstack existing dimensions corresponding to MultiIndexes into
multiple new dimensions.
Expand All @@ -3380,6 +3384,7 @@ def unstack(self, dim: Union[Hashable, Iterable[Hashable]] = None) -> "Dataset":
dim : Hashable or iterable of Hashable, optional
Dimension(s) over which to unstack. By default unstacks all
MultiIndexes.
fill_value: value to be filled. By default, np.nan
Returns
-------
Expand Down Expand Up @@ -3417,7 +3422,7 @@ def unstack(self, dim: Union[Hashable, Iterable[Hashable]] = None) -> "Dataset":

result = self.copy(deep=False)
for dim in dims:
result = result._unstack_once(dim)
result = result._unstack_once(dim, fill_value)
return result

def update(self, other: "CoercibleMapping", inplace: bool = None) -> "Dataset":
Expand Down
17 changes: 17 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2794,6 +2794,23 @@ def test_unstack_errors(self):
with raises_regex(ValueError, "do not have a MultiIndex"):
ds.unstack("x")

def test_unstack_fill_value(self):
ds = xr.Dataset(
{"var": (("x",), np.arange(6))},
coords={"x": [0, 1, 2] * 2, "y": (("x",), ["a"] * 3 + ["b"] * 3)},
)
# make ds incomplete
ds = ds.isel(x=[0, 2, 3, 4]).set_index(index=["x", "y"])
# test fill_value
actual = ds.unstack("index", fill_value=-1)
expected = ds.unstack("index").fillna(-1).astype(np.int)
assert actual["var"].dtype == np.int
assert_equal(actual, expected)

actual = ds["var"].unstack("index", fill_value=-1)
expected = ds["var"].unstack("index").fillna(-1).astype(np.int)
assert actual.equals(expected)

def test_stack_unstack_fast(self):
ds = Dataset(
{
Expand Down

0 comments on commit 56c16e4

Please sign in to comment.