Fix mypy errors in xarray.py, xrutils.py, cache.py (#144)
* update dim typing

* Fix mypy errors in xarray.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* start mypy ci

* Use T_DataArray and T_Dataset

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add mypy ignores

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* correct typing a bit

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* test newer flake8 if ellipsis passes there

* Allow ellipsis in flake8

* Update core.py

* Update xarray.py

* Update setup.cfg

* Update xarray.py

* Update xarray.py

* Update xarray.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update xarray.py

* Update pyproject.toml

* Update xarray.py

* Update xarray.py

* hopefully no more pytest errors.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* make sure expected_groups doesn't have None

* Update flox/xarray.py

Co-authored-by: Deepak Cherian <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* ds_broad and longer comment

* Use same for loop for similar things.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix xrutils.py

* fix errors in cache.py

* Turn off mypy check

* Update flox/xarray.py

Co-authored-by: Deepak Cherian <[email protected]>

* Update flox/xarray.py

Co-authored-by: Deepak Cherian <[email protected]>

* Use if else format to avoid tuple creation

* Update xarray.py

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Deepak Cherian <[email protected]>
3 people authored Sep 23, 2022
1 parent af3e3ce commit 2b54c5e
Showing 5 changed files with 102 additions and 70 deletions.
2 changes: 1 addition & 1 deletion flox/cache.py
@@ -8,4 +8,4 @@
cache = cachey.Cache(1e6)
memoize = partial(cache.memoize, key=dask.base.tokenize)
except ImportError:
memoize = lambda x: x
memoize = lambda x: x # type: ignore
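The `# type: ignore` is needed because mypy infers `memoize` as a `functools.partial` object from the `try` branch and then rejects re-binding it to a bare lambda. A hedged alternative sketch (not what this commit does) that would satisfy mypy without the ignore, by declaring one type up front:

    from functools import partial
    from typing import Any, Callable

    memoize: Callable[..., Any]  # one declared type covers both branches

    try:
        import cachey
        import dask

        cache = cachey.Cache(1e6)
        memoize = partial(cache.memoize, key=dask.base.tokenize)
    except ImportError:
        # the no-op fallback also matches the declared Callable type
        memoize = lambda x: x  # noqa: E731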
15 changes: 12 additions & 3 deletions flox/core.py
@@ -5,7 +5,16 @@
import operator
from collections import namedtuple
from functools import partial, reduce
from typing import TYPE_CHECKING, Any, Callable, Dict, Mapping, Sequence, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
Mapping,
Sequence,
Union,
)

import numpy as np
import numpy_groupies as npg
@@ -1282,8 +1291,8 @@ def _assert_by_is_aligned(shape, by):


def _convert_expected_groups_to_index(
expected_groups: tuple, isbin: bool, sort: bool
) -> pd.Index | None:
expected_groups: Iterable, isbin: Sequence[bool], sort: bool
) -> tuple[pd.Index | None]:
out = []
for ex, isbin_ in zip(expected_groups, isbin):
if isinstance(ex, pd.IntervalIndex) or (isinstance(ex, pd.Index) and not isbin):
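`_convert_expected_groups_to_index` now takes one `isbin` flag per grouper and returns a tuple of indexes rather than a single one. A hedged usage sketch (values are illustrative, not from the commit):

    import pandas as pd

    # one entry per grouper: plain labels, and precomputed bin edges
    expected_groups = (pd.Index([10, 20, 30]), pd.interval_range(0, 4))
    isbin = (False, True)

    # returns a tuple with one pd.Index (or None) per grouper:
    # indexes = _convert_expected_groups_to_index(expected_groups, isbin, sort=True)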
151 changes: 86 additions & 65 deletions flox/xarray.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Hashable, Iterable, Sequence
from typing import TYPE_CHECKING, Any, Hashable, Iterable, Sequence, Union

import numpy as np
import pandas as pd
@@ -19,7 +19,10 @@
from .xrutils import _contains_cftime_datetimes, _to_pytimedelta, datetime_to_numeric

if TYPE_CHECKING:
from xarray import DataArray, Dataset, Resample
from xarray.core.resample import Resample
from xarray.core.types import T_DataArray, T_Dataset

Dims = Union[str, Iterable[Hashable], None]


def _get_input_core_dims(group_names, dim, ds, grouper_dims):
@@ -51,13 +54,13 @@ def lookup_order(dimension):


def xarray_reduce(
obj: Dataset | DataArray,
*by: DataArray | Iterable[str] | Iterable[DataArray],
obj: T_Dataset | T_DataArray,
*by: T_DataArray | Hashable,
func: str | Aggregation,
expected_groups=None,
isbin: bool | Sequence[bool] = False,
sort: bool = True,
dim: Hashable = None,
dim: Dims | ellipsis = None,
split_out: int = 1,
fill_value=None,
method: str = "map-reduce",
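With the new `Dims` alias, `dim` accepts a dimension name, an iterable of names, `None`, or the `...` literal. A hedged call sketch (data is illustrative, not from the commit):

    import numpy as np
    import xarray as xr
    # from flox.xarray import xarray_reduce

    ds = xr.Dataset(
        {"a": ("x", np.arange(4.0))},
        coords={"labels": ("x", ["p", "q", "p", "q"])},
    )
    # any of these dim styles is covered by Dims | ellipsis:
    # xarray_reduce(ds, "labels", func="mean", dim="x")
    # xarray_reduce(ds, "labels", func="mean", dim=("x",))
    # xarray_reduce(ds, "labels", func="mean", dim=...)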
@@ -203,8 +206,11 @@ def xarray_reduce(
if keep_attrs is None:
keep_attrs = True

if isinstance(isbin, bool):
isbin = (isbin,) * nby
if isinstance(isbin, Sequence):
isbins = isbin
else:
isbins = (isbin,) * nby

if expected_groups is None:
expected_groups = (None,) * nby
if isinstance(expected_groups, (np.ndarray, list)): # TODO: test for list
@@ -217,78 +223,86 @@
raise NotImplementedError

# eventually drop the variables we are grouping by
maybe_drop = [b for b in by if isinstance(b, str)]
maybe_drop = [b for b in by if isinstance(b, Hashable)]
unindexed_dims = tuple(
b
for b, isbin_ in zip(by, isbin)
if isinstance(b, str) and not isbin_ and b in obj.dims and b not in obj.indexes
for b, isbin_ in zip(by, isbins)
if isinstance(b, Hashable) and not isbin_ and b in obj.dims and b not in obj.indexes
)

by: tuple[DataArray] = tuple(obj[g] if isinstance(g, str) else g for g in by) # type: ignore
by_da = tuple(obj[g] if isinstance(g, Hashable) else g for g in by)

grouper_dims = []
for g in by:
for g in by_da:
for d in g.dims:
if d not in grouper_dims:
grouper_dims.append(d)

if isinstance(obj, xr.DataArray):
ds = obj._to_temp_dataset()
else:
if isinstance(obj, xr.Dataset):
ds = obj
else:
ds = obj._to_temp_dataset()

ds = ds.drop_vars([var for var in maybe_drop if var in ds.variables])

if dim is Ellipsis:
if nby > 1:
raise NotImplementedError("Multiple by are not allowed when dim is Ellipsis.")
dim = tuple(obj.dims)
if by[0].name in ds.dims and not isbin[0]:
dim = tuple(d for d in dim if d != by[0].name)
name_ = by_da[0].name
if name_ in ds.dims and not isbins[0]:
dim_tuple = tuple(d for d in obj.dims if d != name_)
else:
dim_tuple = tuple(obj.dims)
elif dim is not None:
dim = _atleast_1d(dim)
dim_tuple = _atleast_1d(dim)
else:
dim = tuple()
dim_tuple = tuple()

# broadcast all variables against each other along all dimensions in `by` variables
# don't exclude `dim` because it need not be a dimension in any of the `by` variables!
# in the case where dim is Ellipsis, and by.ndim < obj.ndim
# then we also broadcast `by` to all `obj.dims`
# TODO: avoid this broadcasting
exclude_dims = tuple(d for d in ds.dims if d not in grouper_dims and d not in dim)
ds, *by = xr.broadcast(ds, *by, exclude=exclude_dims)
exclude_dims = tuple(d for d in ds.dims if d not in grouper_dims and d not in dim_tuple)
ds_broad, *by_broad = xr.broadcast(ds, *by_da, exclude=exclude_dims)
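# A tiny standalone sketch of the xr.broadcast behaviour relied on here
# (illustrative, not part of the commit): dims named in `exclude` are
# left untouched; everything else is broadcast to the union of dims.
#
#     a = xr.DataArray([1, 2], dims="x")
#     b = xr.DataArray([1, 2, 3], dims="y")
#     a2, b2 = xr.broadcast(a, b)  # both now 2x3 with dims ("x", "y")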

if not dim:
dim = tuple(by[0].dims)
# all members of by_broad have the same dimensions
# so we just pull by_broad[0].dims if dim is None
if not dim_tuple:
dim_tuple = tuple(by_broad[0].dims)

if any(d not in grouper_dims and d not in obj.dims for d in dim):
if any(d not in grouper_dims and d not in obj.dims for d in dim_tuple):
raise ValueError(f"Cannot reduce over absent dimensions {dim}.")

dims_not_in_groupers = tuple(d for d in dim if d not in grouper_dims)
if dims_not_in_groupers == tuple(dim) and not any(isbin):
dims_not_in_groupers = tuple(d for d in dim_tuple if d not in grouper_dims)
if dims_not_in_groupers == tuple(dim_tuple) and not any(isbins):
# reducing along a dimension along which groups do not vary
# This is really just a normal reduction.
# This is not right when binning so we exclude.
if skipna and isinstance(func, str):
dsfunc = func[3:]
if isinstance(func, str):
dsfunc = func[3:] if skipna else func
else:
dsfunc = func
raise NotImplementedError(
"func must be a string when reducing along a dimension not present in `by`"
)
# TODO: skipna needs test
result = getattr(ds, dsfunc)(dim=dim, skipna=skipna)
result = getattr(ds_broad, dsfunc)(dim=dim_tuple, skipna=skipna)
if isinstance(obj, xr.DataArray):
return obj._from_temp_dataset(result)
else:
return result

axis = tuple(range(-len(dim), 0))
group_names = tuple(g.name if not binned else f"{g.name}_bins" for g, binned in zip(by, isbin))

group_shape = [None] * len(by)
expected_groups = list(expected_groups)
axis = tuple(range(-len(dim_tuple), 0))

# Set expected_groups and convert to index since we need coords, sizes
# for output xarray objects
for idx, (b, expect, isbin_) in enumerate(zip(by, expected_groups, isbin)):
expected_groups = list(expected_groups)
group_names: tuple[Any, ...] = ()
group_sizes: dict[Any, int] = {}
for idx, (b_, expect, isbin_) in enumerate(zip(by_broad, expected_groups, isbins)):
group_name = b_.name if not isbin_ else f"{b_.name}_bins"
group_names += (group_name,)

if isbin_ and isinstance(expect, int):
raise NotImplementedError(
"flox does not support binning into an integer number of bins yet."
@@ -297,13 +311,21 @@ def xarray_reduce(
if isbin_:
raise ValueError(
f"Please provided bin edges for group variable {idx} "
f"named {group_names[idx]} in expected_groups."
f"named {group_name} in expected_groups."
)
expected_groups[idx] = _get_expected_groups(b.data, sort=sort, raise_if_dask=True)

expected_groups = _convert_expected_groups_to_index(expected_groups, isbin, sort=sort)
group_shape = tuple(len(e) for e in expected_groups)
group_sizes = dict(zip(group_names, group_shape))
expect_ = _get_expected_groups(b_.data, sort=sort, raise_if_dask=True)
else:
expect_ = expect
expect_index = _convert_expected_groups_to_index((expect_,), (isbin_,), sort=sort)[0]

# The if-check is for type hinting mainly, it narrows down the return
# type of _convert_expected_groups_to_index to pure pd.Index:
if expect_index is not None:
expected_groups[idx] = expect_index
group_sizes[group_name] = len(expect_index)
else:
# This will never be reached
raise ValueError("expect_index cannot be None")

def wrapper(array, *by, func, skipna, **kwargs):
# Handle skipna here because I need to know dtype to make a good default choice.
@@ -349,20 +371,20 @@ def wrapper(array, *by, func, skipna, **kwargs):
if isinstance(obj, xr.Dataset):
# broadcasting means the group dim gets added to ds, so we check the original obj
for k, v in obj.data_vars.items():
is_missing_dim = not (any(d in v.dims for d in dim))
is_missing_dim = not (any(d in v.dims for d in dim_tuple))
if is_missing_dim:
missing_dim[k] = v

input_core_dims = _get_input_core_dims(group_names, dim, ds, grouper_dims)
input_core_dims = _get_input_core_dims(group_names, dim_tuple, ds_broad, grouper_dims)
input_core_dims += [input_core_dims[-1]] * (nby - 1)

actual = xr.apply_ufunc(
wrapper,
ds.drop_vars(tuple(missing_dim)).transpose(..., *grouper_dims),
*by,
ds_broad.drop_vars(tuple(missing_dim)).transpose(..., *grouper_dims),
*by_broad,
input_core_dims=input_core_dims,
# for xarray's test_groupby_duplicate_coordinate_labels
exclude_dims=set(dim),
exclude_dims=set(dim_tuple),
output_core_dims=[group_names],
dask="allowed",
dask_gufunc_kwargs=dict(output_sizes=group_sizes),
@@ -379,27 +401,27 @@
"engine": engine,
"reindex": reindex,
"expected_groups": tuple(expected_groups),
"isbin": isbin,
"isbin": isbins,
"finalize_kwargs": finalize_kwargs,
},
)

# restore non-dim coord variables without the core dimension
# TODO: shouldn't apply_ufunc handle this?
for var in set(ds.variables) - set(ds.dims):
if all(d not in ds[var].dims for d in dim):
actual[var] = ds[var]
for var in set(ds_broad.variables) - set(ds_broad.dims):
if all(d not in ds_broad[var].dims for d in dim_tuple):
actual[var] = ds_broad[var]

for name, expect, by_ in zip(group_names, expected_groups, by):
for name, expect, by_ in zip(group_names, expected_groups, by_broad):
# Can't remove this till xarray handles IntervalIndex
if isinstance(expect, pd.IntervalIndex):
expect = expect.to_numpy()
if isinstance(actual, xr.Dataset) and name in actual:
actual = actual.drop_vars(name)
# When grouping by MultiIndex, expect is a pd.Index wrapping
# an object array of tuples
if name in ds.indexes and isinstance(ds.indexes[name], pd.MultiIndex):
levelnames = ds.indexes[name].names
if name in ds_broad.indexes and isinstance(ds_broad.indexes[name], pd.MultiIndex):
levelnames = ds_broad.indexes[name].names
expect = pd.MultiIndex.from_tuples(expect.values, names=levelnames)
actual[name] = expect
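# pd.MultiIndex.from_tuples rebuilds proper levels from the object-dtype
# tuples, e.g. (illustrative values):
#
#     pd.MultiIndex.from_tuples([("a", 0), ("b", 1)], names=["letter", "num"])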
if Version(xr.__version__) > Version("2022.03.0"):
@@ -414,18 +436,17 @@

if nby == 1:
for var in actual:
if isinstance(obj, xr.DataArray):
template = obj
else:
if isinstance(obj, xr.Dataset):
template = obj[var]
else:
template = obj

if actual[var].ndim > 1:
actual[var] = _restore_dim_order(actual[var], template, by[0])
actual[var] = _restore_dim_order(actual[var], template, by_broad[0])

if missing_dim:
for k, v in missing_dim.items():
missing_group_dims = {
dim: size for dim, size in group_sizes.items() if dim not in v.dims
}
missing_group_dims = {d: size for d, size in group_sizes.items() if d not in v.dims}
# The expand_dims is for backward compat with xarray's questionable behaviour
if missing_group_dims:
actual[k] = v.expand_dims(missing_group_dims).variable
@@ -439,9 +460,9 @@ def wrapper(array, *by, func, skipna, **kwargs):


def rechunk_for_cohorts(
obj: DataArray | Dataset,
obj: T_DataArray | T_Dataset,
dim: str,
labels: DataArray,
labels: T_DataArray,
force_new_chunk_at,
chunksize: int | None = None,
ignore_old_chunks: bool = False,
@@ -486,7 +507,7 @@ def rechunk_for_cohorts(
)


def rechunk_for_blockwise(obj: DataArray | Dataset, dim: str, labels: DataArray):
def rechunk_for_blockwise(obj: T_DataArray | T_Dataset, dim: str, labels: T_DataArray):
"""
Rechunks array so that group boundaries line up with chunk boundaries, allowing
embarrassingly parallel group reductions.
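`T_DataArray` and `T_Dataset` are TypeVars from `xarray.core.types`; annotating both the parameter and the return with the TypeVar preserves subclass types for callers. A hedged sketch of the pattern with a local stand-in:

    from typing import TypeVar

    import xarray as xr

    T_DA = TypeVar("T_DA", bound=xr.DataArray)

    def passthrough(obj: T_DA) -> T_DA:
        # a DataArray subclass comes back typed as that same subclass
        return obj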
2 changes: 1 addition & 1 deletion flox/xrutils.py
@@ -19,7 +19,7 @@

dask_array_type = dask.array.Array
except ImportError:
dask_array_type = ()
dask_array_type = () # type: ignore


def asarray(data, xp=np):
2 changes: 2 additions & 0 deletions setup.cfg
@@ -57,3 +57,5 @@ per-file-ignores =
exclude=
.eggs
doc
builtins =
ellipsis
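The new `builtins = ellipsis` entry tells flake8/pyflakes to treat the lowercase `ellipsis` name as defined. A hedged sketch of why it is needed (annotation style as in xarray.py above; function name is illustrative):

    from __future__ import annotations

    # `ellipsis` (the type of `...`) is known to type checkers but is not
    # a runtime builtin, so pyflakes reports F821 for the annotation below
    # unless flake8 is told to treat it as a builtin.
    def reduce_over(dim: str | ellipsis | None = None):
        return dim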
