Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement interp for interpolating between chunks of data (dask) #4155

Merged
merged 44 commits into from
Aug 11, 2020
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
62c6385
Implement interp for interpolating between chunks of data (dask)
Jun 15, 2020
f6f7dad
do not forget extra points at the end
Jun 15, 2020
b0d8a5f
add tests
Jun 15, 2020
1a31457
add whats-new comment
Jun 15, 2020
9933c73
fix isort / black
Jun 15, 2020
cea826b
typo
Jun 15, 2020
44bbedf
update pull number
Jun 15, 2020
067b7f3
fix github pep8 warnigns
Jun 15, 2020
c47a1d5
fix isort
Jun 15, 2020
7d505a1
clearer arguments in _dask_aware_interpnd
Jul 17, 2020
423b36d
typo
Jul 17, 2020
85ff539
fix for datetimelike index
Jul 20, 2020
6e9b50e
chunked interpolation does not work for high order interpolation (qua…
Jul 20, 2020
c63636f
Merge branch 'upstream' into chunked_interp
Jul 20, 2020
86cb592
fix whats new
Jul 20, 2020
5e26a4e
remove a useless import
Jul 20, 2020
3ca6e6d
use Variable instead of InexVariable
Jul 21, 2020
a131b21
avoid some list to tuple conversion
Jul 21, 2020
67d2b36
black fix
Jul 21, 2020
f485958
more comments to explain _compute_chunks
Jul 21, 2020
42f8a3b
For orthogonal linear- and nearest-neighbor interpolation, the scalar…
Jul 24, 2020
ec3c400
better detection of Advanced interpolation
Jul 24, 2020
e231954
implement support of unsorted interpolation destination
Jul 24, 2020
061f5a8
rework the tests
Jul 24, 2020
623cb0b
fix for datetime index (bug introduced with unsorted destination)
Jul 24, 2020
b66d123
Variable is cheaber that DataArray
Jul 24, 2020
e211127
add warning if unsorted
Jul 27, 2020
e610268
simplify _compute_chunks
Jul 27, 2020
7547d56
add ghosts point in order to make quadratic and cubic method work in…
Jul 27, 2020
fd936dd
black
Jul 27, 2020
24f9460
forgot to remove an exception in test_upsample_interpolate_dask
Jul 27, 2020
dd2f273
fix filtering out-of-order warning
Jul 28, 2020
49bdefa
use extrapolate to check external points
Jul 28, 2020
d280867
Revert "add ghosts point in order to make quadratic and cubic method …
Jul 29, 2020
aeb7be1
Complete rewrite using blockwise
Jul 29, 2020
3c7d8c6
Merge branch 'upstream' into chunked_interp
Jul 29, 2020
0bc35d2
update whats-new.rst
Jul 29, 2020
0d5f618
reduce the diff
Jul 29, 2020
290a075
more decomposition of orthogonal interpolation
Jul 29, 2020
3f8718e
simplify _dask_aware_interpnd a little
Jul 30, 2020
562d5aa
fix dask interp when chunks are not aligned
Jul 30, 2020
62f059c
continue simplifying _dask_aware_interpnd
Jul 31, 2020
3d4d45c
update whats-new.rst
Jul 31, 2020
b60cddf
clean tests
Jul 31, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ Enhancements
By `Keisuke Fujii <https://github.com/fujiisoup>`_.
- :py:meth:`DataArray.reset_index` and :py:meth:`Dataset.reset_index` now keep
coordinate attributes (:pull:`4103`). By `Oriol Abril <https://github.com/OriolAbril>`_.
- :py:meth:`DataArray.interp` now supports simple interpolation in a chunked dimension
  (but not advanced interpolation) (:pull:`4155`). By `Alexandre Poux <https://github.com/pums974>`_.

New Features
~~~~~~~~~~~~
Expand Down
161 changes: 143 additions & 18 deletions xarray/core/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import warnings
from functools import partial
from numbers import Number
from typing import Any, Callable, Dict, Hashable, Sequence, Union
from typing import Any, Callable, Dict, Hashable, List, Sequence, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -544,13 +544,11 @@ def _get_valid_fill_mask(arr, dim, limit):
) <= limit


def _assert_single_chunk(var, axes):
def _single_chunk(var, axes):
for axis in axes:
if len(var.chunks[axis]) > 1 or var.chunks[axis][0] < var.shape[axis]:
raise NotImplementedError(
"Chunking along the dimension to be interpolated "
"({}) is not yet supported.".format(axis)
)
return False
return True


def _localize(var, indexes_coords):
Expand Down Expand Up @@ -706,22 +704,76 @@ def interp_func(var, x, new_x, method, kwargs):
if isinstance(var, dask_array_type):
import dask.array as da

_assert_single_chunk(var, range(var.ndim - len(x), var.ndim))
chunks = var.chunks[: -len(x)] + new_x[0].shape
drop_axis = range(var.ndim - len(x), var.ndim)
new_axis = range(var.ndim - len(x), var.ndim - len(x) + new_x[0].ndim)
return da.map_blocks(
_interpnd,
var,
x,
new_x,
# easyer, and allows advanced interpolation
if _single_chunk(var, range(var.ndim - len(x), var.ndim)):
chunks = var.chunks[: -len(x)] + new_x[0].shape
drop_axis = range(var.ndim - len(x), var.ndim)
new_axis = range(var.ndim - len(x), var.ndim - len(x) + new_x[0].ndim)
return da.map_blocks(
_interpnd,
var,
x,
new_x,
func,
kwargs,
dtype=var.dtype,
chunks=chunks,
new_axis=new_axis,
drop_axis=drop_axis,
)

current_dims = [_x.name for _x in x]

# number of non interpolated dimensions
nconst = var.ndim - len(x)

# chunks x
x = tuple(
da.from_array(_x, chunks=chunks)
for _x, chunks in zip(x, var.chunks[nconst:])
)

# duplicate the ghost cells of the array in the interpolated dimensions
var_with_ghost, x_with_ghost = _add_interp_ghost(var, x, nconst)

# compute final chunks
target_dims = set.union(*[set(_x.dims) for _x in new_x])
if target_dims - set(current_dims):
raise NotImplementedError(
"Advanced interpolation is not implemented with chunked dimension"
)
new_x = tuple([_x.set_dims(current_dims) for _x in new_x])
total_chunks = _compute_chunks(x, x_with_ghost, new_x)
final_chunks = var.chunks[: -len(x)] + tuple(total_chunks)

# chunks new_x
new_x = tuple(da.from_array(_x, chunks=total_chunks) for _x in new_x)

# reshape x_with_ghost
# TODO: remove it (see _dask_aware_interpnd)
x_with_ghost = da.meshgrid(*x_with_ghost, indexing="ij")

# compute on chunks
res = da.map_blocks(
_dask_aware_interpnd,
var_with_ghost,
func,
kwargs,
len(x_with_ghost),
*x_with_ghost,
*new_x,
dtype=var.dtype,
chunks=chunks,
new_axis=new_axis,
drop_axis=drop_axis,
chunks=final_chunks,
)

# reshape res and remove empty chunks
# TODO: remove it by using drop_axis and new_axis in map_blocks
res = res.squeeze()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it safe to give the squeezing axis explicitly?
What happens if the original array already has a size-one dimension?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No you're right, it's probably not safe, hence the TODO.
But at the time I didn't manage to use drop_axis and new_axis...
I'll give an another try tomorrow.

new_chunks = tuple(
[tuple([chunk for chunk in chunks if chunk > 0]) for chunks in res.chunks]
)
res = res.rechunk(new_chunks)
return res

return _interpnd(var, x, new_x, func, kwargs)

Expand Down Expand Up @@ -751,3 +803,76 @@ def _interpnd(var, x, new_x, func, kwargs):
# move back the interpolation axes to the last position
rslt = rslt.transpose(range(-rslt.ndim + 1, 1))
return rslt.reshape(rslt.shape[:-1] + new_x[0].shape)


def _dask_aware_interpnd(var, func: Callable[..., Any], kwargs: Any, nx: int, *arrs):
    """Wrapper for `_interpnd` allowing dask arrays to be used in `map_blocks`.

    The first `nx` arrays in `arrs` are the original coordinates,
    the remaining ones are the destination coordinates.
    Currently the original coordinates must be passed as full (meshgrid) arrays.

    TODO: find a way to use 1d coordinates
    """
    from .dataarray import DataArray

    _old_x, _new_x = arrs[:nx], arrs[nx:]

    # Recover each 1d coordinate from its meshgrid-style nd array: move that
    # coordinate's own axis last, then take the first element along every
    # other axis.  (TODO REMOVE once 1d coordinates can be passed directly.)
    old_x = tuple(
        np.moveaxis(tmp, dim, -1)[(0,) * (tmp.ndim - 1)]
        for dim, tmp in enumerate(_old_x)
    )

    # _interpnd expects the destination coordinates as DataArrays
    new_x = tuple(DataArray(_x) for _x in _new_x)

    return _interpnd(var, old_x, new_x, func, kwargs)


def _add_interp_ghost(var, x, nconst: int):
    """Duplicate the ghost cells of the array (values and coordinates).

    Parameters
    ----------
    var : dask array to be interpolated
    x : tuple of 1d dask arrays, the coordinates along the interpolated axes
    nconst : number of leading (non-interpolated) dimensions of ``var``

    Returns
    -------
    var_with_ghost, x_with_ghost : ``var`` and ``x`` with a depth-1 overlap
        added along every interpolated dimension (no wrap-around, ``boundary="none"``).
    """
    import dask.array as da

    # no overlap along the constant leading dimensions, one ghost cell on
    # each side of every interpolated dimension
    bnd = {i: "none" for i in range(var.ndim)}
    depth = {i: 0 if i < nconst else 1 for i in range(var.ndim)}

    var_with_ghost = da.overlap.overlap(var, depth=depth, boundary=bnd)

    # each 1d coordinate gets the same depth-1 overlap along its only axis
    x_with_ghost = tuple(
        da.overlap.overlap(_x, depth={0: 1}, boundary={0: "none"}) for _x in x
    )
    return var_with_ghost, x_with_ghost


def _compute_chunks(x, x_with_ghost, new_x):
"""Compute equilibrated chunks of new_x

TODO: This only works if new_x is a set of 1d coordinate
more general function is needed for advanced interpolation with chunked dimension
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add more doc for this function? It is difficult to follow the logic for me...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it better now ?

"""
chunks_end = [np.cumsum(sizes) - 1 for _x in x for sizes in _x.chunks]
chunks_end_with_ghost = [
np.cumsum(sizes) - 1 for _x in x_with_ghost for sizes in _x.chunks
]
total_chunks = []
for dim, ce in enumerate(zip(chunks_end, chunks_end_with_ghost)):
l_new_x_ends: List[np.ndarray] = []
for iend, iend_with_ghost in zip(*ce):

arr = np.moveaxis(new_x[dim].data, dim, -1)
arr = arr[tuple([0] * (len(arr.shape) - 1))]
pums974 marked this conversation as resolved.
Show resolved Hide resolved

n_no_ghost = (arr <= x[dim][iend]).sum()
n_ghost = (arr <= x_with_ghost[dim][iend_with_ghost]).sum()

equil = np.ceil(0.5 * (n_no_ghost + n_ghost)).astype(int)

l_new_x_ends.append(equil)

new_x_ends = np.array(l_new_x_ends)
# do not forget extra points at the end
new_x_ends[-1] = len(arr)
chunks = new_x_ends[0], *(new_x_ends[1:] - new_x_ends[:-1])
total_chunks.append(tuple(chunks))
return total_chunks
50 changes: 45 additions & 5 deletions xarray/tests/test_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,6 @@ def test_interpolate_1d(method, dim, case):
da = get_example_data(case)
xdest = np.linspace(0.0, 0.9, 80)

if dim == "y" and case == 1:
with pytest.raises(NotImplementedError):
actual = da.interp(method=method, **{dim: xdest})
pytest.skip("interpolation along chunked dimension is " "not yet supported")

actual = da.interp(method=method, **{dim: xdest})

# scipy interpolation for the reference
Expand Down Expand Up @@ -717,3 +712,48 @@ def test_decompose(method):
actual = da.interp(x=x_new, y=y_new, method=method).drop(("x", "y"))
expected = da.interp(x=x_broadcast, y=y_broadcast, method=method).drop(("x", "y"))
assert_allclose(actual, expected)


def test_interpolate_chunk_1d():
    """Interpolating along a chunked dimension must match the eager result."""
    if not has_scipy:
        pytest.skip("scipy is not installed.")

    if not has_dask:
        pytest.skip("dask is not installed in the environment.")

    data = get_example_data(1)
    dest = np.linspace(-0.1, 0.2, 80)

    expected = data.compute().interp(method="linear", y=dest)
    actual = data.interp(method="linear", y=dest)

    assert_allclose(actual, expected)


@pytest.mark.parametrize("scalar_nx", [True, False])
def test_interpolate_chunk_nd(scalar_nx):
    """Chunked nd interpolation must match the eager result, including
    destination points before the data, between two chunks, and after the data."""
    if not has_scipy:
        pytest.skip("scipy is not installed.")

    if not has_dask:
        pytest.skip("dask is not installed in the environment.")

    data = get_example_data(1).chunk({"x": 50})

    if scalar_nx:
        # 0.5 is between chunks
        xdest = 0.5
    else:
        # -0.5 before data, 0.5 between chunks, 1.5 after data
        xdest = [-0.5, 0.25, 0.5, 0.75, 1.5]
    # -0.1 before data, 0.05 between chunks, 0.15 after data
    ydest = [-0.1, 0.025, 0.05, 0.075, 0.15]

    expected = data.compute().interp(method="linear", x=xdest, y=ydest)
    actual = data.interp(method="linear", x=xdest, y=ydest)

    assert_allclose(actual, expected)