Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow rank to run on dask arrays #8475

Merged
merged 4 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 12 additions & 15 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2063,6 +2063,7 @@ def rank(self, dim, pct=False):
--------
Dataset.rank, DataArray.rank
"""
# This could / should arguably be implemented at the DataArray & Dataset level
if not OPTIONS["use_bottleneck"]:
raise RuntimeError(
"rank requires bottleneck to be enabled."
Expand All @@ -2071,24 +2072,20 @@ def rank(self, dim, pct=False):

import bottleneck as bn

data = self.data

if is_duck_dask_array(data):
raise TypeError(
"rank does not work for arrays stored as dask "
"arrays. Load the data via .compute() or .load() "
"prior to calling this method."
)
elif not isinstance(data, np.ndarray):
raise TypeError(f"rank is not implemented for {type(data)} objects.")

axis = self.get_axis_num(dim)
func = bn.nanrankdata if self.dtype.kind == "f" else bn.rankdata
ranked = func(data, axis=axis)
ranked = xr.apply_ufunc(
func,
self,
input_core_dims=[[dim]],
output_core_dims=[[dim]],
dask="parallelized",
kwargs=dict(axis=-1),
).transpose(*self.dims)

if pct:
count = np.sum(~np.isnan(data), axis=axis, keepdims=True)
count = self.notnull().sum(dim)
ranked /= count
return Variable(self.dims, ranked)
return ranked

def rolling_window(
self, dim, window, window_dim, center=False, fill_value=dtypes.NA
Expand Down
20 changes: 16 additions & 4 deletions xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1878,9 +1878,20 @@ def test_quantile_out_of_bounds(self, q):

@requires_dask
@requires_bottleneck
def test_rank_dask_raises(self):
v = Variable(["x"], [3.0, 1.0, np.nan, 2.0, 4.0]).chunk(2)
with pytest.raises(TypeError, match=r"arrays stored as dask"):
def test_rank_dask(self):
# Instead of a single dask-specific test here, we could parametrize the other
# rank tests over both numpy- and dask-backed data. But this is sufficient.
v = Variable(
["x", "y"], [[30.0, 1.0, np.nan, 20.0, 4.0], [30.0, 1.0, np.nan, 20.0, 4.0]]
).chunk(x=1)
expected = Variable(
["x", "y"], [[4.0, 1.0, np.nan, 3.0, 2.0], [4.0, 1.0, np.nan, 3.0, 2.0]]
)
assert_equal(v.rank("y").compute(), expected)

with pytest.raises(
ValueError, match=r" with dask='parallelized' consists of multiple chunks"
):
v.rank("x")

def test_rank_use_bottleneck(self):
Expand Down Expand Up @@ -1912,7 +1923,8 @@ def test_rank(self):
v_expect = Variable(["x"], [0.75, 0.25, np.nan, 0.5, 1.0])
assert_equal(v.rank("x", pct=True), v_expect)
# invalid dim
with pytest.raises(ValueError, match=r"not found"):
with pytest.raises(ValueError):
# The error message raised via apply_ufunc isn't great here (`ValueError: tuple.index(x): x not in tuple`), so we don't match on it.
v.rank("y")

def test_big_endian_reduce(self):
Expand Down
Loading