Skip to content

Commit

Permalink
dask support
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Jul 28, 2024
1 parent 275f574 commit 5564c4f
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 8 deletions.
13 changes: 5 additions & 8 deletions flox/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,8 +518,8 @@ def _pick_second(*x):

first = Aggregation("first", chunk=None, combine=None, fill_value=0)
last = Aggregation("last", chunk=None, combine=None, fill_value=0)
nanfirst = Aggregation("nanfirst", chunk="nanfirst", combine="nanfirst", fill_value=np.nan)
nanlast = Aggregation("nanlast", chunk="nanlast", combine="nanlast", fill_value=np.nan)
nanfirst = Aggregation("nanfirst", chunk="nanfirst", combine=xrutils.nanfirst, fill_value=np.nan)
nanlast = Aggregation("nanlast", chunk="nanlast", combine=xrutils.nanlast, fill_value=np.nan)

all_ = Aggregation(
"all",
Expand Down Expand Up @@ -577,8 +577,8 @@ def topk_new_dims_func(k) -> tuple[Dim]:
topk = Aggregation(
name="topk",
fill_value=dtypes.NINF,
chunk=None,
combine=None,
chunk="topk",
combine=xrutils.topk,
final_dtype=None,
new_dims_func=topk_new_dims_func,
)
Expand Down Expand Up @@ -881,10 +881,7 @@ def _initialize_aggregation(
simple_combine: list[Callable | None] = []
for combine in agg.combine:
if isinstance(combine, str):
if combine in ["nanfirst", "nanlast"]:
simple_combine.append(getattr(xrutils, combine))
else:
simple_combine.append(getattr(np, combine))
simple_combine.append(getattr(np, combine))
else:
simple_combine.append(combine)

Expand Down
18 changes: 18 additions & 0 deletions flox/xrutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,3 +378,21 @@ def nanlast(values, axis, keepdims=False):
return np.expand_dims(result, axis=axis)
else:
return result


def topk(a, k, axis, keepdims):
"""Chunk and combine function of topk
Extract the k largest elements from a on the given axis.
If k is negative, extract the -k smallest elements instead.
Note that, unlike in the parent function, the returned elements
are not sorted internally.
"""
assert keepdims is True
axis = axis[0]
if abs(k) >= a.shape[axis]:
return a

a = np.partition(a, -k, axis=axis)
k_slice = slice(-k, None) if k > 0 else slice(-k)
return a[tuple(k_slice if i == axis else slice(None) for i in range(a.ndim))]

0 comments on commit 5564c4f

Please sign in to comment.