Support complex arrays in xr.corr (#7392)

* complex cov

* fix mypy

* update whatsa-new

* Update xarray/core/computation.py

* slight improvements to tests

* bugfix in corr_cov for multiple dims

* fix whats-new

* allow refreshing of backends

* Revert "allow refreshing of backends"

This reverts commit 576692b.

---------

Co-authored-by: Deepak Cherian <[email protected]>
headtr1ck and dcherian authored Feb 14, 2023
1 parent cd90184 commit 21d8645
Showing 3 changed files with 82 additions and 60 deletions.
2 changes: 2 additions & 0 deletions doc/whats-new.rst
@@ -23,6 +23,8 @@ v2023.03.0 (unreleased)
New Features
~~~~~~~~~~~~

- :py:meth:`xr.cov` and :py:meth:`xr.corr` now support complex valued arrays (:issue:`7340`, :pull:`7392`).
By `Michael Niklas <https://github.com/headtr1ck>`_.
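A minimal sketch of the behavior this entry describes (values worked out by hand, not taken from this commit):

import xarray as xr

da = xr.DataArray([1 + 1j, 2 - 1j, 3 + 0j], dims="time")

# Auto-covariance of a complex series is real and non-negative: the
# deviations from the mean (2+0j) are [-1+1j, -1j, 1+0j], so
# sum(abs(dev) ** 2) / (n - 1) = (2 + 1 + 1) / 2 = 2.
print(xr.cov(da, da).item())   # (2+0j)

# Autocorrelation normalizes to ~1:
print(xr.corr(da, da).item())  # (1+0j), up to floating point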

Breaking changes
~~~~~~~~~~~~~~~~
46 changes: 30 additions & 16 deletions xarray/core/computation.py
@@ -9,7 +9,16 @@
import warnings
from collections import Counter
from collections.abc import Hashable, Iterable, Mapping, Sequence
from typing import TYPE_CHECKING, AbstractSet, Any, Callable, TypeVar, Union, overload
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
Callable,
Literal,
TypeVar,
Union,
overload,
)

import numpy as np

@@ -21,7 +30,7 @@
from xarray.core.merge import merge_attrs, merge_coordinates_without_align
from xarray.core.options import OPTIONS, _get_keep_attrs
from xarray.core.pycompat import is_duck_dask_array
from xarray.core.types import T_DataArray
from xarray.core.types import Dims, T_DataArray
from xarray.core.utils import is_dict_like, is_scalar
from xarray.core.variable import Variable

@@ -1209,7 +1218,9 @@ def apply_ufunc(
return apply_array_ufunc(func, *args, dask=dask)


def cov(da_a, da_b, dim=None, ddof=1):
def cov(
da_a: T_DataArray, da_b: T_DataArray, dim: Dims = None, ddof: int = 1
) -> T_DataArray:
"""
Compute covariance between two DataArray objects along a shared dimension.
@@ -1219,9 +1230,9 @@ def cov(da_a, da_b, dim=None, ddof=1):
Array to compute.
da_b : DataArray
Array to compute.
dim : str, optional
dim : str, iterable of hashable, "..." or None, optional
The dimension along which the covariance will be computed
ddof : int, optional
ddof : int, default: 1
If ddof=1, covariance is normalized by N-1, giving an unbiased estimate,
else normalization is by N.
@@ -1289,7 +1300,7 @@ def cov(da_a, da_b, dim=None, ddof=1):
return _cov_corr(da_a, da_b, dim=dim, ddof=ddof, method="cov")


def corr(da_a, da_b, dim=None):
def corr(da_a: T_DataArray, da_b: T_DataArray, dim: Dims = None) -> T_DataArray:
"""
Compute the Pearson correlation coefficient between
two DataArray objects along a shared dimension.
@@ -1300,7 +1311,7 @@ def corr(da_a, da_b, dim=None):
Array to compute.
da_b : DataArray
Array to compute.
dim : str, optional
dim : str, iterable of hashable, "..." or None, optional
The dimension along which the correlation will be computed
Returns
@@ -1368,7 +1379,11 @@ def corr(da_a, da_b, dim=None):


def _cov_corr(
da_a: T_DataArray, da_b: T_DataArray, dim=None, ddof=0, method=None
da_a: T_DataArray,
da_b: T_DataArray,
dim: Dims = None,
ddof: int = 0,
method: Literal["cov", "corr", None] = None,
) -> T_DataArray:
"""
Internal method for xr.cov() and xr.corr() so only have to
@@ -1388,22 +1403,21 @@ def _cov_corr(
demeaned_da_b = da_b - da_b.mean(dim=dim)

# 4. Compute covariance along the given dim
#
# N.B. `skipna=True` is required or auto-covariance is computed incorrectly. E.g.
# Try xr.cov(da,da) for da = xr.DataArray([[1, 2], [1, np.nan]], dims=["x", "time"])
cov = (demeaned_da_a * demeaned_da_b).sum(dim=dim, skipna=True, min_count=1) / (
valid_count
)
cov = (demeaned_da_a.conj() * demeaned_da_b).sum(
dim=dim, skipna=True, min_count=1
) / (valid_count)

if method == "cov":
return cov
return cov # type: ignore[return-value]

else:
# compute std + corr
da_a_std = da_a.std(dim=dim)
da_b_std = da_b.std(dim=dim)
corr = cov / (da_a_std * da_b_std)
return corr
return corr # type: ignore[return-value]
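The one functional change above is the .conj() on the demeaned first argument: the covariance becomes the mean of conj(a - a.mean()) * (b - b.mean()), so the auto-covariance reduces to a sum of abs(deviation) ** 2 and agrees with NumPy's complex variance. A standalone check under that reading (not part of the diff):

import numpy as np
import xarray as xr

da = xr.DataArray([1j, -1j, 2 + 1j], dims="time")

# xr.cov conjugates the first argument, so cov(da, da) sums abs(dev) ** 2
actual = xr.cov(da, da)               # ddof=1 by default
expected = np.var(da.values, ddof=1)  # np.var uses abs(dev) ** 2 for complex input
np.testing.assert_allclose(actual.values, expected)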


def cross(
@@ -1616,7 +1630,7 @@ def cross(

def dot(
*arrays,
dims: str | Iterable[Hashable] | ellipsis | None = None,
dims: Dims = None,
**kwargs: Any,
):
"""Generalized dot product for xarray objects. Like np.einsum, but
@@ -1626,7 +1640,7 @@ def dot(
----------
*arrays : DataArray or Variable
Arrays to compute.
dims : ..., str or tuple of str, optional
dims : str, iterable of hashable, "..." or None, optional
Which dimensions to sum over. Ellipsis ('...') sums over all dimensions.
If not specified, then all the common dimensions are summed over.
**kwargs : dict
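The dims argument follows the convention that the Dims alias now encodes for cov and corr as well: a single dimension name, an iterable of names, ... for every dimension, or None for the dimensions common to all inputs. A short sketch of the convention (not from the diff):

import numpy as np
import xarray as xr

da_a = xr.DataArray(np.arange(6).reshape(3, 2), dims=("a", "b"))
da_b = xr.DataArray(np.arange(6).reshape(2, 3), dims=("b", "c"))

xr.dot(da_a, da_b)            # dims=None: contract the shared dim "b" -> dims ("a", "c")
xr.dot(da_a, da_b, dims="b")  # the same contraction, spelled explicitly
xr.dot(da_a, da_b, dims=...)  # contract every dimension -> 0-d result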
94 changes: 50 additions & 44 deletions xarray/tests/test_computation.py
@@ -1387,25 +1387,29 @@ def test_vectorize_exclude_dims_dask() -> None:

def test_corr_only_dataarray() -> None:
with pytest.raises(TypeError, match="Only xr.DataArray is supported"):
xr.corr(xr.Dataset(), xr.Dataset())
xr.corr(xr.Dataset(), xr.Dataset()) # type: ignore[type-var]


def arrays_w_tuples():
@pytest.fixture(scope="module")
def arrays():
da = xr.DataArray(
np.random.random((3, 21, 4)),
coords={"time": pd.date_range("2000-01-01", freq="1D", periods=21)},
dims=("a", "time", "x"),
)

arrays = [
return [
da.isel(time=range(0, 18)),
da.isel(time=range(2, 20)).rolling(time=3, center=True).mean(),
xr.DataArray([[1, 2], [1, np.nan]], dims=["x", "time"]),
xr.DataArray([[1, 2], [np.nan, np.nan]], dims=["x", "time"]),
xr.DataArray([[1, 2], [2, 1]], dims=["x", "time"]),
]

array_tuples = [

@pytest.fixture(scope="module")
def array_tuples(arrays):
return [
(arrays[0], arrays[0]),
(arrays[0], arrays[1]),
(arrays[1], arrays[1]),
@@ -1417,27 +1421,19 @@ def arrays_w_tuples():
(arrays[4], arrays[4]),
]

return arrays, array_tuples
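Because @pytest.mark.parametrize cannot consume fixture values directly, the refactored tests below parametrize over an integer index n and look the pair up inside the test body through the module-scoped fixture. The pattern in isolation (hypothetical names, not from this commit):

import pytest

@pytest.fixture(scope="module")
def pairs():
    # built once per module, like the arrays/array_tuples fixtures above
    return [(1, 1), (2, 4), (3, 9)]

@pytest.mark.parametrize("n", range(3))
def test_square(n, pairs):
    x, y = pairs[n]
    assert x * x == y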


@pytest.mark.parametrize("ddof", [0, 1])
@pytest.mark.parametrize(
"da_a, da_b",
[
arrays_w_tuples()[1][3],
arrays_w_tuples()[1][4],
arrays_w_tuples()[1][5],
arrays_w_tuples()[1][6],
arrays_w_tuples()[1][7],
arrays_w_tuples()[1][8],
],
)
@pytest.mark.parametrize("n", [3, 4, 5, 6, 7, 8])
@pytest.mark.parametrize("dim", [None, "x", "time"])
@requires_dask
def test_lazy_corrcov(da_a, da_b, dim, ddof) -> None:
def test_lazy_corrcov(
n: int, dim: str | None, ddof: int, array_tuples: tuple[xr.DataArray, xr.DataArray]
) -> None:
# GH 5284
from dask import is_dask_collection

da_a, da_b = array_tuples[n]

with raise_if_dask_computes():
cov = xr.cov(da_a.chunk(), da_b.chunk(), dim=dim, ddof=ddof)
assert is_dask_collection(cov)
@@ -1447,12 +1443,13 @@ def test_lazy_corrcov(da_a, da_b, dim, ddof) -> None:


@pytest.mark.parametrize("ddof", [0, 1])
@pytest.mark.parametrize(
"da_a, da_b",
[arrays_w_tuples()[1][0], arrays_w_tuples()[1][1], arrays_w_tuples()[1][2]],
)
@pytest.mark.parametrize("n", [0, 1, 2])
@pytest.mark.parametrize("dim", [None, "time"])
def test_cov(da_a, da_b, dim, ddof) -> None:
def test_cov(
n: int, dim: str | None, ddof: int, array_tuples: tuple[xr.DataArray, xr.DataArray]
) -> None:
da_a, da_b = array_tuples[n]

if dim is not None:

def np_cov_ind(ts1, ts2, a, x):
@@ -1499,12 +1496,13 @@ def np_cov(ts1, ts2):
assert_allclose(actual, expected)


@pytest.mark.parametrize(
"da_a, da_b",
[arrays_w_tuples()[1][0], arrays_w_tuples()[1][1], arrays_w_tuples()[1][2]],
)
@pytest.mark.parametrize("n", [0, 1, 2])
@pytest.mark.parametrize("dim", [None, "time"])
def test_corr(da_a, da_b, dim) -> None:
def test_corr(
n: int, dim: str | None, array_tuples: tuple[xr.DataArray, xr.DataArray]
) -> None:
da_a, da_b = array_tuples[n]

if dim is not None:

def np_corr_ind(ts1, ts2, a, x):
@@ -1547,12 +1545,12 @@ def np_corr(ts1, ts2):
assert_allclose(actual, expected)


@pytest.mark.parametrize(
"da_a, da_b",
arrays_w_tuples()[1],
)
@pytest.mark.parametrize("n", range(9))
@pytest.mark.parametrize("dim", [None, "time", "x"])
def test_covcorr_consistency(da_a, da_b, dim) -> None:
def test_covcorr_consistency(
n: int, dim: str | None, array_tuples: tuple[xr.DataArray, xr.DataArray]
) -> None:
da_a, da_b = array_tuples[n]
# Testing that xr.corr and xr.cov are consistent with each other
# 1. Broadcast the two arrays
da_a, da_b = broadcast(da_a, da_b)
@@ -1569,10 +1567,13 @@ def test_covcorr_consistency(da_a, da_b, dim) -> None:


@requires_dask
@pytest.mark.parametrize("da_a, da_b", arrays_w_tuples()[1])
@pytest.mark.parametrize("n", range(9))
@pytest.mark.parametrize("dim", [None, "time", "x"])
@pytest.mark.filterwarnings("ignore:invalid value encountered in .*divide")
def test_corr_lazycorr_consistency(da_a, da_b, dim) -> None:
def test_corr_lazycorr_consistency(
n: int, dim: str | None, array_tuples: tuple[xr.DataArray, xr.DataArray]
) -> None:
da_a, da_b = array_tuples[n]
da_al = da_a.chunk()
da_bl = da_b.chunk()
c_abl = xr.corr(da_al, da_bl, dim=dim)
@@ -1591,22 +1592,27 @@ def test_corr_dtype_error():
xr.testing.assert_equal(xr.corr(da_a, da_b), xr.corr(da_a, da_b.chunk()))


@pytest.mark.parametrize(
"da_a",
arrays_w_tuples()[0],
)
@pytest.mark.parametrize("n", range(5))
@pytest.mark.parametrize("dim", [None, "time", "x", ["time", "x"]])
def test_autocov(da_a, dim) -> None:
def test_autocov(n: int, dim: str | None, arrays) -> None:
da = arrays[n]

# Testing that the autocovariance*(N-1) is ~=~ to the variance matrix
# 1. Ignore the nans
valid_values = da_a.notnull()
valid_values = da.notnull()
# Because we're using ddof=1, this requires > 1 value in each sample
da_a = da_a.where(valid_values.sum(dim=dim) > 1)
expected = ((da_a - da_a.mean(dim=dim)) ** 2).sum(dim=dim, skipna=True, min_count=1)
actual = xr.cov(da_a, da_a, dim=dim) * (valid_values.sum(dim) - 1)
da = da.where(valid_values.sum(dim=dim) > 1)
expected = ((da - da.mean(dim=dim)) ** 2).sum(dim=dim, skipna=True, min_count=1)
actual = xr.cov(da, da, dim=dim) * (valid_values.sum(dim) - 1)
assert_allclose(actual, expected)


def test_complex_cov() -> None:
da = xr.DataArray([1j, -1j])
actual = xr.cov(da, da)
assert abs(actual.item()) == 2
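The new test pins down only the covariance case. A hypothetical companion check for corr under the same convention (not part of this commit) might look like:

def test_complex_corr() -> None:
    # autocorrelation of a complex series should normalize to ~1
    da = xr.DataArray([1j, -1j, 1 + 1j], dims="x")
    actual = xr.corr(da, da)
    np.testing.assert_allclose(actual.values, 1.0)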


@requires_dask
def test_vectorize_dask_new_output_dims() -> None:
# regression test for GH3574
