cross post dask detrending from xarray blog
scharlottej13 committed Nov 20, 2024
1 parent 0d3d6e1 commit 2ca3fd4
Showing 5 changed files with 125 additions and 0 deletions.
125 changes: 125 additions & 0 deletions _posts/2024-11-21-dask-detrending.md
@@ -0,0 +1,125 @@
---
layout: post
title: Improving GroupBy.map with Dask and Xarray
author: Patrick Hoefler
tags: [dask array, xarray]
theme: twitter
canonical_url: https://xarray.dev/blog/dask-detrending
---

{% include JB/setup %}

Running large-scale GroupBy-Map patterns with Xarray that are backed by [Dask arrays](https://docs.dask.org/en/stable/array.html?utm_source=dask-blog) is
an essential part of a lot of typical geospatial workloads. Detrending is a very common
operation where this pattern is needed.

In this post, we will explore how and why this caused so many pitfalls for Xarray users in
the past and how we improved performance and scalability with a few changes to how Dask
subselects data.

## What is GroupBy.map?

[`GroupBy.map`](https://docs.xarray.dev/en/stable/generated/xarray.core.groupby.DatasetGroupBy.map.html) lets you apply a User Defined Function (UDF)
that accepts and returns Xarray objects. The UDF will receive an Xarray object (either a Dataset or a DataArray) containing Dask arrays corresponding to one single group.
[`GroupBy.reduce`](https://docs.xarray.dev/en/stable/generated/xarray.core.groupby.DatasetGroupBy.reduce.html) is quite similar
in that it applies a UDF, but in this case the UDF will receive the underlying Dask arrays, _not_ Xarray objects.
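
To make the `GroupBy.map` contract concrete, here is a minimal sketch on hypothetical toy data (two years of daily values on a 2x2 grid, not the dataset used in this post). The names `data`, `demean`, and `seen` are made up for illustration, and the array is NumPy-backed for brevity; in the setting described here it would be backed by Dask, for example after calling `.chunk(...)`.

```python
import numpy as np
import pandas as pd
import xarray as xr

# Hypothetical toy data: two non-leap years of daily values on a 2x2 grid.
time = pd.date_range("2001-01-01", "2002-12-31", freq="D")
data = xr.DataArray(
    np.random.default_rng(0).normal(size=(time.size, 2, 2)),
    coords={"time": time},
    dims=("time", "y", "x"),
)

seen = []  # records (type name, group length) for every UDF call


def demean(group: xr.DataArray) -> xr.DataArray:
    # The UDF receives one group as an Xarray object and returns one.
    seen.append((type(group).__name__, group.sizes["time"]))
    return group - group.mean("time")


result = data.groupby("time.dayofyear").map(demean)
```

Each call to `demean` sees a `DataArray` that holds only the entries for one day of the year (two entries here, one per year), and the combined result has the same shape as the input.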

## The Application

Consider a typical workflow where you want to apply a detrending step. You may want to smooth out
the data by removing the trends over time. This is a common operation in climate science
and normally looks roughly like this:

```python
from xarray import DataArray


def detrending_step(arr: DataArray) -> DataArray:
    # important: the rolling operation is applied within a group
    return arr - arr.rolling(time=30, min_periods=1).mean()

# `data` is the Dask-backed DataArray described below

data.groupby("time.dayofyear").map(detrending_step)
```

We group by the day of the year and then calculate the rolling average over
30-year windows for each day.

Our example will run on a 1 TiB array holding 64 years' worth of data, with the following structure:

![Python repr output of a 1 TiB Dask array with shape (1801, 3600, 23376) split into 5460 chunks of 250 MiB each, with chunk shape (300, 300, 365)](/images/dask-detrending/input-array.png)

The array isn't overly huge and the chunks are reasonably sized.

## The Problem

The general application seems straightforward. Group by the day of the year and apply a UDF
to every group. There are a few pitfalls in this application that can make the result of
this operation unusable. Our array is sorted by time, which means that we have to pick
entries from many different areas in the array to create a single group (corresponding to a single day of the year).
Picking the same day of every year is basically a slicing operation with a step size of 365.

![Schematic showing an array sorted by time, where data is selected from many different areas in the array to create a single group (corresponding to a specific day of the year).](/images/dask-detrending/indexing-data-selection.png "Data Selection Pattern")

Our example has a year's worth of data in a single chunk along the time axis. The general problem
exists for any workload where you have to access random entries of data. This
particular access pattern means that we have to pick one value per chunk, which is pretty
inefficient. The right side of the schematic shows the individual groups that we are operating on.
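
The access pattern can be sketched with plain NumPy index arithmetic. This is only an illustration under simplifying assumptions: leap days are ignored, every time chunk holds exactly 365 days, and the variable names are hypothetical.

```python
import numpy as np

n_years = 64          # matches the "64 years" of data mentioned above
days_per_year = 365   # assumption: leap days ignored for simplicity
n_time = n_years * days_per_year

day = 41  # arbitrary zero-based day of the year
# Selecting one day of every year is a strided slice with step 365.
positions = np.arange(n_time)[day::days_per_year]

# With one year per chunk, the chunk index is simply position // 365.
chunk_ids = positions // days_per_year
values_per_chunk = np.bincount(chunk_ids, minlength=n_years)
```

Every entry of `values_per_chunk` is 1: building a single group touches all 64 time chunks and takes exactly one time step from each.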

One of the main issues with this pattern is that Dask will create a single output chunk per time
entry, i.e. each group will consist of as many chunks as we have years.

This results in a huge increase in the number of chunks:

![Python repr output of a 1 TiB Dask array with nearly 2 million, 700 kiB chunks.](/images/dask-detrending/output-array-old.png)

This simple operation increases the number of chunks from roughly 5,500 to close to 2 million. Each
chunk only has a few hundred kilobytes of data. **This is pretty bad!**

Dask computations generally scale with the number of chunks you have. Increasing the number of chunks by such
a large factor is catastrophic. Each follow-up operation, even one as simple as `a - b`, will create 2 million
additional tasks.
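
The chunk counts can be checked with a few lines of arithmetic. The sketch below assumes a `(lat, lon, time)` shape of `(1801, 3600, 23376)` with `float64` values, chosen because it reproduces the 5460 input chunks and the roughly 1 TiB total size; `n_chunks` is a hypothetical helper, not a Dask function.

```python
import math

shape = (1801, 3600, 23376)  # assumed (lat, lon, time) shape, about 64 years of daily data
chunks = (300, 300, 365)     # input chunking: one year per chunk along time


def n_chunks(shape, chunks):
    # Number of chunks is the product of the per-axis chunk counts.
    return math.prod(math.ceil(s / c) for s, c in zip(shape, chunks))


before = n_chunks(shape, chunks)          # 7 * 12 * 65 = 5460
after = n_chunks(shape, (300, 300, 1))    # one output chunk per time entry
chunk_kib = 300 * 300 * 1 * 8 / 1024      # size of a full output chunk in KiB
```

Under these assumptions `after` is 1,963,584, about 360 times more chunks than before, and each full chunk is about 703 KiB.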

The only workaround for users was to rechunk to something more sensible afterward, but it
still keeps the incredibly expensive indexing operation in the graph.
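
As a rough sketch of that workaround at the Dask level, the example below indexes a small stand-in array and then merges the time axis back into one chunk with `rechunk`. The array sizes and names are hypothetical, and the chunking right after indexing depends on the installed Dask version, so it is printed rather than assumed.

```python
import numpy as np
import dask.array as da

# Toy stand-in: 4 "years" of 10 "days" each, one year per chunk along the last axis.
arr = da.ones((4, 4, 40), chunks=(2, 2, 10))

# Select the same "day" from every year (positions 3, 13, 23, 33).
idx = np.arange(3, 40, 10)
group = arr[:, :, idx]

# How the selection is chunked along time depends on the Dask version.
print(group.chunks[2])

# Workaround from the text: merge the time axis back into a single chunk.
rechunked = group.rechunk({2: -1})
print(rechunked.chunks[2])
```

The rechunk only adds tasks on top of the indexing step, so any fragmentation produced by the selection is still part of the graph.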

Note this is the underlying problem that is [solved by flox](https://xarray.dev/blog/flox) for aggregations like `.mean()`
using parallel-native algorithms to avoid the expense of indexing out each group.

## Improvements to the Data Selection algorithm

The method of how Dask selected the data was objectively pretty bad.
A rewrite of the underlying algorithm enabled us to achieve a much more robust result. The new
algorithm is a lot smarter about how to pick values from each individual chunk, but most importantly,
it will try to preserve the input chunksize as closely as possible.

For our initial example, it will put every group into a single chunk. This means that we will
end up with the number of chunks along the time axis being equal to the number of groups, i.e. 365.

![Python repr output of a 1 TiB Dask array with 31164, 43 MiB chunks](/images/dask-detrending/output-array-new.png)

The algorithm reduces the number of chunks from 2 million to roughly 30 thousand, which is a huge improvement
and a scale that Dask can easily handle. The graph is now much smaller, and the follow-up operations
will run a lot faster as well.
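
The new numbers can be approximated the same way. This back-of-the-envelope sketch assumes `float64` values, 64 time steps per group, and the 365 groups stated above; the count in the figure is slightly higher than this estimate.

```python
import math

# Spatial chunking is unchanged: ceil(1801 / 300) * ceil(3600 / 300) = 7 * 12.
spatial_chunks = math.ceil(1801 / 300) * math.ceil(3600 / 300)

n_groups = 365  # one chunk per group along time, as described in the text
n_years = 64    # assumed number of time steps in each group

total_chunks = spatial_chunks * n_groups        # 84 * 365
chunk_mib = 300 * 300 * n_years * 8 / 2**20     # size of a full chunk in MiB
```

This gives 30,660 chunks of just under 44 MiB each, in line with the "roughly 30 thousand" chunks quoted above.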

This improvement will help every operation that selects data in this way and make it scale a lot more
reliably than before. The algorithm is used very widely across Dask and Xarray and thus influences
many methods.

## What's next?

Xarray selects one group at a time for `groupby(...).map(...)`, i.e. this requires one operation
per group. This will hurt scalability if the dataset has a very large number of groups, because
the computation will create a very expensive graph. There is currently an effort to implement alternative
APIs that are shuffle-based to circumvent that problem. A current PR is available [here](https://github.com/pydata/xarray/pull/9320).

The fragmentation of the output chunks by indexing is something that will hurt every workflow that is selecting data in a random
pattern. This also includes:

- `.sel` if you aren't using slices explicitly
- `.isel`
- `.sortby`
- `groupby(...).quantile()`
- and many more.

We expect all of these workloads to be substantially improved now.

Additionally, [Dask improved a lot of things](https://docs.dask.org/en/stable/changelog.html#v2024-11-1) related to either increasing chunksizes or fragmentation
of chunks over the lifecycle of a workload, with more improvements to come. This will help a lot of
users to get better and more reliable performance.
Binary file added images/dask-detrending/input-array.png
Binary file added images/dask-detrending/output-array-new.png
Binary file added images/dask-detrending/output-array-old.png