xarray.map #4484

Closed
1 change: 1 addition & 0 deletions doc/api.rst
@@ -33,6 +33,7 @@ Top-level functions
corr
dot
polyval
map
map_blocks
show_versions
set_options
4 changes: 4 additions & 0 deletions doc/whats-new.rst
@@ -31,6 +31,10 @@ New Features
now works with ``engine="zarr"`` (:issue:`3668`, :pull:`4003`, :pull:`4187`).
By `Miguel Jimenez <https://github.com/Mikejmnez>`_ and `Wei Ji Leong <https://github.com/weiji14>`_.

- New function :py:func:`map` introduced as a generalization of :py:meth:`Dataset.map`
for the case when the mapping function takes more than one DataArray as input.
By `Andras Gefferth <https://github.com/kefirbandi>`_.
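As a quick illustration of the entry above (not part of the documentation diff; ``xr.map`` here refers to the function added by this PR):

import xarray as xr

ds1 = xr.Dataset({"foo": ("x", [1.0, 2.0])})
ds2 = xr.Dataset({"foo": ("x", [10.0, 20.0])})

# func receives one DataArray per dataset, matched by variable name.
result = xr.map([ds1, ds2], lambda a, b: 0.5 * (a + b))
# result["foo"] holds [5.5, 11.0]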

Bug fixes
~~~~~~~~~

2 changes: 2 additions & 0 deletions xarray/__init__.py
@@ -23,6 +23,7 @@
from .core.dataarray import DataArray
from .core.dataset import Dataset
from .core.extensions import register_dataarray_accessor, register_dataset_accessor
from .core.map import map
from .core.merge import MergeError, merge
from .core.options import set_options
from .core.parallel import map_blocks
@@ -60,6 +61,7 @@
"infer_freq",
"load_dataarray",
"load_dataset",
"map",
"map_blocks",
"merge",
"ones_like",
17 changes: 8 additions & 9 deletions xarray/core/dataset.py
@@ -94,7 +94,6 @@
infix_dims,
is_dict_like,
is_scalar,
maybe_wrap_array,
)
from .variable import (
IndexVariable,
@@ -4393,15 +4392,15 @@ def map(
Data variables:
foo (dim_0, dim_1) float64 1.764 0.4002 0.9787 2.241 1.868 0.9773
bar (x) float64 1.0 2.0

See also
--------
xarray.map
"""
variables = {
k: maybe_wrap_array(v, func(v, *args, **kwargs))
for k, v in self.data_vars.items()
}
if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=False)
attrs = self.attrs if keep_attrs else None
return type(self)(variables, attrs=attrs)
from xarray.core.map import map

if keep_attrs is not None:
    # True -> keep this dataset's attrs (index 0); False -> drop attrs.
    # None is passed through so xarray.map falls back to the global keep_attrs option.
    keep_attrs = 0 if keep_attrs else False
return map([self], func, keep_attrs, args, kwargs)

def apply(
self,
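With this change ``Dataset.map`` becomes a thin wrapper over the new top-level function, so the single-dataset case should be equivalent to calling ``xr.map`` with a one-element list. A rough sketch, assuming a build of xarray that includes this PR:

import numpy as np
import xarray as xr

ds = xr.Dataset({"foo": ("x", np.arange(3.0)), "bar": ("y", [1.0, 4.0])})

# Dataset.map applies func to every data variable ...
a = ds.map(np.square)
# ... which now delegates to xr.map with a single dataset.
b = xr.map([ds], np.square)
xr.testing.assert_identical(a, b)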
107 changes: 107 additions & 0 deletions xarray/core/map.py
@@ -0,0 +1,107 @@
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    Optional,
    Sequence,
    Union,
)

from .dataset import Dataset
from .options import _get_keep_attrs
from .utils import maybe_wrap_array


def map(
    datasets: Sequence[Dataset],
    func: Callable,
    keep_attrs: Optional[Union[bool, int]] = None,
    args: Iterable[Any] = (),
    kwargs: Optional[Dict] = None,
) -> Dataset:
"""Apply a function to each variable in the provided dataset(s).

The function may take several DataArrays as inputs: one DataArray per Dataset, so
the number of DataArrays passed to the function equals ``len(datasets)``.

The Datasets are expected to share data variable names. For every variable name
that is present in all Datasets, the function is applied to the corresponding
DataArrays; variable names missing from any Dataset are skipped.

Parameters
----------
datasets : sequence of Dataset
The Datasets whose data variables provide the input DataArrays of the function.
func : callable
Function which can be called in the form ``func(x, y, z, ..., *args, **kwargs)``
to transform each tuple of DataArrays ``x``, ``y``, ``z``, ... taken from the
datasets into another DataArray.
keep_attrs : int or bool, optional
If False, the new object will be returned without attributes.
If an integer between 0 and ``len(datasets) - 1``, it gives the index of the Dataset
in ``datasets`` whose attributes are copied to the result; True is equivalent to 0.
If None (default), the global ``keep_attrs`` option is used.
args : tuple, optional
Positional arguments passed on to `func`.
kwargs : dict, optional
Keyword arguments passed on to `func`.

Returns
-------
applied : Dataset
Resulting dataset from applying ``func`` to each tuple of data variables.

Examples
--------
>>> da = xr.DataArray([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]])
>>> ds1 = xr.Dataset({"foo": da, "bar": ("x", [-1, 2])})
>>> ds2 = xr.Dataset({"foo": da + 1, "bar": ("x", [-1, 2])})
>>> ds1
<xarray.Dataset>
Dimensions: (dim_0: 2, dim_1: 3, x: 2)
Dimensions without coordinates: dim_0, dim_1, x
Data variables:
foo (dim_0, dim_1) float64 1.1 2.2 3.3 4.4 5.5 6.6
bar (x) int64 -1 2
>>> ds2
<xarray.Dataset>
Dimensions: (dim_0: 2, dim_1: 3, x: 2)
Dimensions without coordinates: dim_0, dim_1, x
Data variables:
foo (dim_0, dim_1) float64 2.1 3.2 4.3 5.4 6.5 7.6
bar (x) int64 -1 2
>>> f = lambda a, b: b - a
>>> xr.map([ds1, ds2], f)
<xarray.Dataset>
Dimensions: (dim_0: 2, dim_1: 3, x: 2)
Dimensions without coordinates: dim_0, dim_1, x
Data variables:
foo (dim_0, dim_1) float64 1.0 1.0 1.0 1.0 1.0 1.0
bar (x) int64 0 0


See Also
--------
Dataset.map
"""
if kwargs is None:
    kwargs = {}
variables = {}
if len(datasets):
    shared_variable_names = set.intersection(
        *(set(ds.data_vars) for ds in datasets)
    )
    for k in shared_variable_names:
        data_arrays = [d[k] for d in datasets]
        v = maybe_wrap_array(
            datasets[0][k], func(*(data_arrays + list(args)), **kwargs)
        )
Member
I would prefer not to use maybe_wrap_array in new functions. It exists in map only for legacy reasons, but we'd really like to put that sort of functionality into a separate dedicated function (e.g., see the proposal for apply_raw in #1618).

Contributor Author
No problem for me. I copied the existing behavior without really paying attention to what it is doing and why.
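For context on this exchange, ``maybe_wrap_array`` is roughly the following small helper from ``xarray/core/utils.py`` (paraphrased sketch; its docstring typo is fixed later in this diff):

import numpy as np

def maybe_wrap_array(original, new_array):
    # Re-wrap the transformed array with the original's __array_wrap__ only when
    # the shape is unchanged; otherwise return the transformed array as-is.
    if isinstance(new_array, np.ndarray) and new_array.shape == original.shape:
        return original.__array_wrap__(new_array)
    return new_array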

        variables[k] = v

if keep_attrs is None:
    keep_attrs = _get_keep_attrs(default=False)
if keep_attrs is True:
    # True means "use the first dataset"; an explicit integer index is preserved.
    keep_attrs = 0

if keep_attrs is not False:
    attrs = datasets[keep_attrs].attrs
else:
    attrs = None

return Dataset(variables, attrs=attrs)
2 changes: 1 addition & 1 deletion xarray/core/utils.py
@@ -133,7 +133,7 @@ def multiindex_from_product_levels(


def maybe_wrap_array(original, new_array):
"""Wrap a transformed array with __array_wrap__ is it can be done safely.
"""Wrap a transformed array with __array_wrap__ if it can be done safely.

This lets us treat arbitrary functions that take and return ndarray objects
like ufuncs, as long as they return an array with the same shape.
63 changes: 63 additions & 0 deletions xarray/tests/test_map.py
@@ -0,0 +1,63 @@
import numpy as np

import xarray as xr
from xarray.testing import assert_allclose, assert_identical


class TestMap:
def test_2_dim(self):
da = xr.DataArray(np.random.randn(2, 3))
ds1 = xr.Dataset({"foo": da, "bar": ("x", [-1, 2])})
ds2 = xr.Dataset({"foo": da + 1, "bar": ("x", [0, 3])})

f = lambda a, b: b - a
r = xr.map([ds1, ds2], f)
assert_allclose(r, xr.ones_like(ds1))

def test_different_variables_with_overlap(self):
ds1 = xr.Dataset({"foo": ("x", [1, 2]), "bar": ("x", [-1, 2]),
"oof" : ("x", [3, 4])})
ds2 = xr.Dataset({"foo": ("x", [11, 22]), "oof": ("x", [-1, 3])})
ds3 = xr.Dataset({"bar": ("x", [11, 22]), "oof": ("x", [-1, 2])})

ds_out = xr.Dataset({"oof": ("x", [3, 5])})
f = lambda x, y, z: x + y - z
r = xr.map([ds1, ds2, ds3], f)
assert_identical(r, ds_out)

def test_no_variable_overlap(self):
ds1 = xr.Dataset({"foo": ("x", [1, 2]), "oof": ("x", [3, 4])})
ds2 = xr.Dataset({"bar": ("x", [11, 22]), "rab": ("x", [-1, 3])})

ds_out = xr.Dataset()
f = lambda x, y: x + y
r = xr.map([ds1, ds2], f)
assert_identical(r, ds_out)

def test_with_args_and_kwargs(self):
ds1 = xr.Dataset({"foo": ("x", [1, 1]), "oof": ("x", [3, 3])})
ds2 = xr.Dataset({"foo": ("x", [2, 2]), "oof": ("x", [4, 4])})

ds_out = xr.Dataset({"foo": ("x", [8, 8]), "oof": ("x", [18, 18])})

def f(da1, da2, multiplier_1, multiplier_2):
return multiplier_1 * da1 + multiplier_2 * da2

r = xr.map([ds1, ds2], f, args=[2], kwargs={"multiplier_2": 3})
assert_identical(r, ds_out)

def test_keep_attrs(self):
ds1 = xr.Dataset({"foo": ("x", [1, 1]), "oof": ("x", [3, 3])}, attrs={'value': 1})
ds2 = xr.Dataset({"foo": ("x", [2, 2]), "oof": ("x", [4, 4])}, attrs={'value': 2})
ds3 = xr.Dataset({"foo": ("x", [3, 3]), "oof": ("x", [5, 5])}, attrs={'value': 3})

def f(da1, da2, da3, selector):
return [da1, da2, da3][selector]

r = xr.map([ds1, ds2, ds3], f, args=[2], keep_attrs=0)
assert r.attrs["value"] == 1

r = xr.map([ds1, ds2, ds3], f, args=[0], keep_attrs=1)
assert r.attrs["value"] == 2