Switch to TimeResampler objects
dcherian committed Jun 22, 2024
1 parent acfca63 commit 5ffb9e4
Showing 6 changed files with 41 additions and 22 deletions.
10 changes: 10 additions & 0 deletions doc/user-guide/dask.rst
@@ -556,6 +556,16 @@ larger chunksizes.

Check out the `dask documentation on chunks <https://docs.dask.org/en/latest/array-chunks.html>`_.

.. tip::

Many time-domain problems become amenable to an embarrassingly parallel or blockwise solution
(e.g. using :py:func:`xarray.map_blocks`, :py:func:`dask.array.map_blocks`, or
:py:func:`dask.array.blockwise`) by rechunking to a frequency along the time dimension.
Provide :py:class:`groupers.TimeResampler` objects to :py:meth:`Dataset.chunk` to do so.
For example, ``ds.chunk(time=TimeResampler("MS"))`` will set the chunks so that a month of
data is contained in one chunk. The resulting chunk sizes need not be uniform; they depend on
the frequency of the data and on the calendar.
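
A minimal sketch of the pattern the tip describes (an editor's illustration, not part of this commit), assuming dask is installed and using the ``xarray.core.groupers`` import path introduced here:

    import numpy as np
    import xarray as xr
    from xarray.core.groupers import TimeResampler

    # A year of daily data.
    time = xr.date_range("2001-01-01", periods=365, freq="D")
    ds = xr.Dataset({"foo": ("time", np.arange(365))}, coords={"time": time})

    # One chunk per calendar month; sizes follow the calendar instead of being uniform.
    chunked = ds.chunk(time=TimeResampler("MS"))
    print(chunked.chunksizes["time"])  # something like (31, 28, 31, 30, ...)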


Optimization Tips
-----------------
11 changes: 6 additions & 5 deletions xarray/core/dataarray.py
@@ -1400,29 +1400,30 @@ def chunk(
xarray.unify_chunks
dask.array.from_array
"""
chunk_mapping: T_ChunksFreq
if chunks is None:
warnings.warn(
"None value for 'chunks' is deprecated. "
"It will raise an error in the future. Use instead '{}'",
category=FutureWarning,
)
chunks = {}
chunk_mapping = {}

if isinstance(chunks, (float, str, int)):
# ignoring type; unclear why it won't accept a Literal into the value.
chunks = dict.fromkeys(self.dims, chunks)
chunk_mapping = dict.fromkeys(self.dims, chunks)
elif isinstance(chunks, (tuple, list)):
utils.emit_user_level_warning(
"Supplying chunks as dimension-order tuples is deprecated. "
"It will raise an error in the future. Instead use a dict with dimension names as keys.",
category=DeprecationWarning,
)
chunks = dict(zip(self.dims, chunks))
chunk_mapping = dict(zip(self.dims, chunks))
else:
chunks = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk")
chunk_mapping = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk")

ds = self._to_temp_dataset().chunk(
chunks,
chunk_mapping,
name_prefix=name_prefix,
token=token,
lock=lock,
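
An editor's sketch (not part of the commit) of the input forms that the normalization above funnels into ``chunk_mapping``; the array ``da`` is illustrative and dask is assumed to be installed:

    import numpy as np
    import xarray as xr

    da = xr.DataArray(np.zeros((4, 6)), dims=("x", "y"))

    da.chunk({"x": 2, "y": 3})  # mapping of dimension name to chunk size
    da.chunk(x=2, y=3)          # the same mapping, via keyword arguments
    da.chunk(2)                 # a scalar is applied to every dimension
    da.chunk((2, 3))            # dimension-order tuple; deprecated, emits a DeprecationWarning
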
16 changes: 9 additions & 7 deletions xarray/core/dataset.py
@@ -2708,6 +2708,7 @@ def chunk(
dask.array.from_array
"""
from xarray.core.dataarray import DataArray
from xarray.core.groupers import TimeResampler

if chunks is None and not chunks_kwargs:
warnings.warn(
@@ -2734,27 +2735,28 @@
f"chunks keys {tuple(bad_dims)} not found in data dimensions {tuple(self.sizes.keys())}"
)

def _resolve_frequency(name: Hashable, freq: str) -> tuple[int, ...]:
def _resolve_frequency(
name: Hashable, resampler: TimeResampler
) -> tuple[int, ...]:
variable = self._variables.get(name, None)
if variable is None:
raise ValueError(
f"Cannot chunk by frequency string {freq!r} for virtual variables."
f"Cannot chunk by resampler {resampler!r} for virtual variables."
)
elif not _contains_datetime_like_objects(variable):
raise ValueError(
f"chunks={freq!r} only supported for datetime variables. "
f"chunks={resampler!r} only supported for datetime variables. "
f"Received variable {name!r} with dtype {variable.dtype!r} instead."
)

assert variable.ndim == 1
chunks: tuple[int, ...] = tuple(
DataArray(
np.ones(variable.shape, dtype=int),
dims=(name,),
coords={name: variable},
)
# TODO: This could be generalized to `freq` being a `Resampler` object,
# and using `groupby` instead of `resample`
.resample({name: freq})
.resample({name: resampler})
.sum()
.data.tolist()
)
@@ -2763,7 +2765,7 @@ def _resolve_frequency(name: Hashable, freq: str) -> tuple[int, ...]:
chunks_mapping_ints: Mapping[Any, T_ChunkDim] = {
name: (
_resolve_frequency(name, chunks)
if isinstance(chunks, str) and chunks != "auto"
if isinstance(chunks, TimeResampler)
else chunks
)
for name, chunks in chunks_mapping.items()
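
An editor's sketch (not part of the commit) of the trick ``_resolve_frequency`` relies on: resampling an array of ones and summing counts the time steps that fall in each period, which is exactly the chunk size wanted for that period. The method itself passes the ``TimeResampler`` object to ``.resample``; a plain frequency string is used here so the snippet stands alone:

    import numpy as np
    import xarray as xr

    time = xr.date_range("2001-01-01", periods=90, freq="D")
    ones = xr.DataArray(np.ones(90, dtype=int), dims=("time",), coords={"time": time})

    # Number of daily samples in each month == chunk size for that month.
    chunks = tuple(ones.resample(time="MS").sum().data.tolist())
    print(chunks)  # (31, 28, 31)
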
9 changes: 8 additions & 1 deletion xarray/core/groupers.py
@@ -238,7 +238,7 @@ def factorize(self, group: T_Group) -> EncodedGroups:
)


@dataclass
@dataclass(repr=False)
class TimeResampler(Resampler):
"""
Grouper object specialized to resampling the time coordinate.
@@ -297,6 +297,13 @@ class TimeResampler(Resampler):
index_grouper: CFTimeGrouper | pd.Grouper = field(init=False)
group_as_index: pd.Index = field(init=False)

def __repr__(self):
return (
f"<TimeResampler freq={self.freq!r}, closed={self.closed!r}, "
f"label={self.label!r}, origin={self.origin!r}, "
f"offset={self.offset!r}, loffset={self.loffset!r}, base={self.base!r}>"
)

def __post_init__(self):
if self.loffset is not None:
emit_user_level_warning(
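
Editor's note on the ``repr=False`` / ``__repr__`` change above: the dataclass-generated repr would try to read the ``init=False`` fields (``index_grouper``, ``group_as_index``), which are not populated at construction time, so a hand-written repr is presumably what keeps the object printable in error messages such as the ``{resampler!r}`` interpolations in ``Dataset.chunk``. A small hedged sketch:

    from xarray.core.groupers import TimeResampler

    print(TimeResampler("MS"))
    # prints something like:
    # <TimeResampler freq='MS', closed=None, label=None, ...>
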
7 changes: 3 additions & 4 deletions xarray/core/types.py
@@ -38,6 +38,7 @@
from xarray.core.coordinates import Coordinates
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.groupers import TimeResampler
from xarray.core.indexes import Index, Indexes
from xarray.core.utils import Frozen
from xarray.core.variable import Variable
@@ -190,10 +191,8 @@ def copy(
# FYI in some cases we don't allow `None`, which this doesn't take account of.
# FYI the `str` is for a size string, e.g. "16MB", supported by dask.
T_ChunkDim: TypeAlias = Union[str, int, Literal["auto"], None, tuple[int, ...]]
T_FreqStr: TypeAlias = str
T_ChunkDim: TypeAlias = Union[int, Literal["auto"], None, tuple[int, ...]]
T_ChunkDimFreq: TypeAlias = Union[T_FreqStr, T_ChunkDim]
T_ChunksFreq: TypeAlias = Union[T_ChunkDimFreq, Mapping[Any, T_ChunkDimFreq]]
T_ChunkDimFreq: TypeAlias = Union["TimeResampler", T_ChunkDim]
T_ChunksFreq: TypeAlias = Union[T_ChunkDim, Mapping[Any, T_ChunkDimFreq]]
# We allow the tuple form of this (though arguably we could transition to named dims only)
T_Chunks: TypeAlias = Union[T_ChunkDim, Mapping[Any, T_ChunkDim]]
T_NormalizedChunks = tuple[tuple[int, ...], ...]
10 changes: 5 additions & 5 deletions xarray/tests/test_dataset.py
@@ -20,7 +20,6 @@
from numpy.exceptions import RankWarning # type: ignore[attr-defined,unused-ignore]
except ImportError:
from numpy import RankWarning

import xarray as xr
from xarray import (
DataArray,
@@ -38,6 +37,7 @@
from xarray.core import dtypes, indexing, utils
from xarray.core.common import duck_array_ops, full_like
from xarray.core.coordinates import Coordinates, DatasetCoordinates
from xarray.core.groupers import TimeResampler
from xarray.core.indexes import Index, PandasIndex
from xarray.core.utils import is_scalar
from xarray.namedarray.pycompat import array_type, integer_types
@@ -1219,20 +1219,20 @@ def test_chunk_by_frequency(self, freq, calendar) -> None:
)
},
)
actual = ds.chunk(time=freq).chunksizes["time"]
actual = ds.chunk(time=TimeResampler(freq)).chunksizes["time"]
expected = tuple(ds.ones.resample(time=freq).sum().data.tolist())
assert expected == actual

def test_chunk_by_frequecy_errors(self):
ds = Dataset({"foo": ("x", [1, 2, 3])})
with pytest.raises(ValueError, match="virtual variable"):
ds.chunk(x="YE")
ds.chunk(x=TimeResampler("YE"))
ds["x"] = ("x", [1, 2, 3])
with pytest.raises(ValueError, match="datetime variables"):
ds.chunk(x="YE")
ds.chunk(x=TimeResampler("YE"))
ds["x"] = ("x", xr.date_range("2001-01-01", periods=3, freq="D"))
with pytest.raises(ValueError, match="Invalid frequency"):
ds.chunk(x="foo")
ds.chunk(x=TimeResampler("foo"))

@requires_dask
def test_dask_is_lazy(self) -> None: