Add option to control rechunking in reshape (dask#6753)
* Avoid rechunking in reshape with chunksize=1

When the slow-moving (early) axes in `.reshape` are all size 1, then we
can avoid an intermediate rechunk which could cause memory issues.

```
00 01 | 02 03   # a[0, :, :]
----- | -----
04 05 | 06 07
08 09 | 10 11

=============

12 13 | 14 15   # a[1, :, :]
----- | -----
16 17 | 18 19
20 21 | 22 23

-> (3, 4)

00 01 | 02 03
----- | -----
04 05 | 06 07
08 09 | 10 11
----- | -----
12 13 | 14 15
----- | -----
16 17 | 18 19
20 21 | 22 23
```

xref dask#5544, specifically the examples
given in dask#5544 (comment).
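The equivalence the commit relies on can be sketched in plain NumPy (no dask required; the array here is illustrative):

```python
import numpy as np

# Sketch of the idea above: when every slow-moving (leading) axis has
# chunksize 1, each chunk can be reshaped independently and stacked --
# no data ever needs to cross a chunk boundary.
a = np.arange(24).reshape(2, 3, 4)

# Treat a[0] and a[1] as two chunks of size 1 along axis 0.
per_chunk = np.concatenate([chunk.reshape(3, 4) for chunk in a], axis=0)

assert (per_chunk == a.reshape(6, 4)).all()
```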

* fix condition

* remove breakpoint comment

* API: Added merge_chunks to reshape

Adds a keyword to reshape to control merge / rechunking. See the
documentation for an explanation.

* update images
TomAugspurger authored and abduhbm committed Jan 19, 2021
1 parent c61244d commit 0a18611
Showing 9 changed files with 152 additions and 4 deletions.
10 changes: 8 additions & 2 deletions dask/array/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1858,12 +1858,18 @@ def choose(self, choices):
return choose(self, choices)

@derived_from(np.ndarray)
def reshape(self, *shape):
def reshape(self, *shape, merge_chunks=True):
"""
.. note::
See :meth:`dask.array.reshape` for an explanation of
the ``merge_chunks`` keyword.
"""
from .reshape import reshape

if len(shape) == 1 and not isinstance(shape[0], Number):
shape = shape[0]
return reshape(self, shape)
return reshape(self, shape, merge_chunks=merge_chunks)

def topk(self, k, axis=-1, split_every=None):
"""The top k elements of an array.
Expand Down
27 changes: 25 additions & 2 deletions dask/array/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,24 @@ def contract_tuple(chunks, factor):
return tuple(out)


def reshape(x, shape):
def reshape(x, shape, merge_chunks=True):
"""Reshape array to new shape
Parameters
----------
shape : int or tuple of ints
The new shape should be compatible with the original shape. If
an integer, then the result will be a 1-D array of that length.
One shape dimension can be -1. In this case, the value is
inferred from the length of the array and remaining dimensions.
merge_chunks : bool, default True
Whether to merge chunks using the logic in :meth:`dask.array.rechunk`
when communication is necessary given the input array chunking and
the output shape. With ``merge_chunks==False``, the input array will
be rechunked to a chunksize of 1, which can create very many tasks.

Notes
-----
This is a parallelized version of the ``np.reshape`` function with the
following limitations:
Expand All @@ -158,6 +173,9 @@ def reshape(x, shape):
When communication is necessary this algorithm depends on the logic within
rechunk. It endeavors to keep chunk sizes roughly the same when possible.
See :ref:`array-chunks.reshaping` for a discussion of the tradeoffs of
``merge_chunks``.

See Also
--------
dask.array.rechunk
Expand Down Expand Up @@ -201,7 +219,12 @@ def reshape(x, shape):
graph = HighLevelGraph.from_collections(name, dsk, dependencies=[x])
return Array(graph, name, chunks, meta=meta)

# Logic for how to rechunk
din = len(x.shape)
dout = len(shape)
if not merge_chunks and din > dout:
x = x.rechunk({i: 1 for i in range(din - dout)})

inchunks, outchunks = reshape_rechunk(x.shape, shape, x.chunks)
x2 = x.rechunk(inchunks)

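The rechunk spec built by the ``merge_chunks=False`` branch above can be checked in isolation (the shapes below are illustrative):

```python
# With merge_chunks=False and din > dout, the din - dout extra
# slow-moving axes are rechunked to chunksize 1. For a (2, 3, 4) -> (6, 4)
# reshape, din=3 and dout=2, so only axis 0 is split:
din, dout = 3, 2
spec = {i: 1 for i in range(din - dout)}
assert spec == {0: 1}

# Collapsing (2, 2, 3, 4) -> (12, 4) splits the first two axes instead:
din, dout = 4, 2
spec = {i: 1 for i in range(din - dout)}
assert spec == {0: 1, 1: 1}
```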
Expand Down
53 changes: 53 additions & 0 deletions dask/array/tests/test_reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,56 @@ def test_reshape_all_not_chunked_merge(
result = a.reshape(outshape)
assert result.chunks == outchunks
assert_eq(result, base.reshape(outshape))


@pytest.mark.parametrize(
"inshape, inchunks, outshape, outchunks",
[
# (2, 3, 4) -> (6, 4)
((2, 3, 4), ((2,), (1, 2), (2, 2)), (6, 4), ((1, 2, 1, 2), (2, 2))),
# (1, 2, 3, 4) -> (6, 4)
((1, 2, 3, 4), ((1,), (2,), (1, 2), (2, 2)), (6, 4), ((1, 2, 1, 2), (2, 2))),
# (2, 2, 3, 4) -> (12, 4) (3 cases)
(
(2, 2, 3, 4),
((1, 1), (2,), (1, 2), (2, 2)),
(12, 4),
((1, 2, 1, 2, 1, 2, 1, 2), (2, 2)),
),
(
(2, 2, 3, 4),
((2,), (1, 1), (1, 2), (2, 2)),
(12, 4),
((1, 2, 1, 2, 1, 2, 1, 2), (2, 2)),
),
(
(2, 2, 3, 4),
((2,), (2,), (1, 2), (2, 2)),
(12, 4),
((1, 2, 1, 2, 1, 2, 1, 2), (2, 2)),
),
# (2, 2, 3, 4) -> (4, 3, 4)
# TODO: I'm confused about the behavior in this case.
# (
# (2, 2, 3, 4),
# ((2,), (2,), (1, 2), (2, 2)),
# (4, 3, 4),
# ((1, 1, 1, 1), (1, 2), (2, 2)),
# ),
# (2, 2, 3, 4) -> (4, 3, 4)
((2, 2, 3, 4), ((2,), (2,), (1, 2), (4,)), (4, 3, 4), ((2, 2), (1, 2), (4,))),
],
)
def test_reshape_merge_chunks(inshape, inchunks, outshape, outchunks):
# https://github.com/dask/dask/issues/5544#issuecomment-712280433
# When the early axes are fully chunked (chunksize 1) then we are just moving
# blocks and can avoid any rechunking. The outchunks will always be ...
base = np.arange(np.prod(inshape)).reshape(inshape)
a = da.from_array(base, chunks=inchunks)

# and via reshape
result = a.reshape(outshape, merge_chunks=False)
assert result.chunks == outchunks
assert_eq(result, base.reshape(outshape))

assert result.chunks != a.reshape(outshape).chunks
66 changes: 66 additions & 0 deletions docs/source/array-chunks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,72 @@ You can pass rechunk any valid chunking form:
x = x.rechunk({0: 50, 1: 1000})

.. _array-chunks.reshaping:

Reshaping
---------

The efficiency of :func:`dask.array.reshape` can depend strongly on the chunking
of the input array. In reshaping operations, there's the concept of "fast-moving"
or "high" axes. For a 2d array the second axis (``axis=1``) is the fastest-moving,
followed by the first. This means that if we draw a line indicating how values
are filled, we move across the "columns" first (along ``axis=1``), and then down
to the next row.

.. image:: images/reshape.png

Now consider the impact of Dask's chunking on this operation. If the slow-moving
axis (just ``axis=0`` in this case) has chunks larger than size 1, we run into
a problem.

.. image:: images/reshape_problem.png

The first block has a shape ``(2, 2)``. Following the rules of ``reshape`` we
take the two values from the first row of block 1. But then we cross a chunk
boundary (from 1 to 2) while we still have two "unused" values in the first
block. There's no way to line up the input blocks with the output shape. We
need to somehow rechunk the input to be compatible with the output shape. We
have two options:

1. Merge chunks using the logic in :meth:`dask.array.rechunk`. This avoids
   making too many tasks / blocks, at the cost of some communication and
   larger intermediates. This is the default behavior.
2. Use ``da.reshape(x, shape, merge_chunks=False)`` to avoid merging chunks
   by *splitting the input*. In particular, we can rechunk all the
   slow-moving axes to have a chunksize of 1. This avoids communication
   and moving around large amounts of data, at the cost of a larger task
   graph (potentially much larger, since the number of chunks on the
   slow-moving axes will equal the length of those axes).
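The boundary problem motivating these options can be reproduced in a few lines of plain NumPy (illustrative, no dask required):

```python
import numpy as np

# NumPy reproduction of the problem above: an input block whose
# slow-moving axis has size 2 feeds non-contiguous rows of the output.
a = np.arange(24).reshape(2, 3, 4)
block = a[:, :2, :]          # a (2, 2, 4) block spanning both slabs of axis 0
out = a.reshape(6, 4)

# Output rows drawn entirely from this block: 0, 1 (from a[0]) and
# 3, 4 (from a[1]) -- split by rows 2 and 5, which come from elsewhere.
rows = [i for i in range(6) if np.isin(out[i], block).all()]
assert rows == [0, 1, 3, 4]
```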

Visually, here's the second option:

.. image:: images/reshape_rechunked.png

Which of these is better depends on your problem. If communication is very
expensive and your data is relatively small along the slow-moving axes, then
``merge_chunks=False`` may be better. Let's compare the task graphs of these
two options on a problem reshaping a 3-d array to a 2-d one, where the input
array doesn't have ``chunksize=1`` on the slow-moving axes.

.. code-block:: python

   >>> a = da.from_array(np.arange(24).reshape(2, 3, 4), chunks=((2,), (2, 1), (2, 2)))
   >>> a
   dask.array<array, shape=(2, 3, 4), dtype=int64, chunksize=(2, 2, 2), chunktype=numpy.ndarray>
   >>> a.reshape(6, 4).visualize()

.. image:: images/merge_chunks.png

.. code-block:: python

   >>> a.reshape(6, 4, merge_chunks=False).visualize()

.. image:: images/merge_chunks_false.png

By default, some intermediate chunks are merged, leading to a more complicated
task graph. With ``merge_chunks=False`` we split the input chunks (leading to
more overall tasks, depending on the size of the array) but avoid later
communication.
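A back-of-the-envelope count makes the trade-off concrete (the shapes below are made up for illustration):

```python
from math import prod

# merge_chunks=False rechunks the din - dout slow-moving axes to
# chunksize 1, so the chunk count along them equals their total length.
inshape = (50, 60, 4)        # hypothetical 3-d input
din, dout = len(inshape), 2  # reshaping to 2-d
slow_chunks = prod(inshape[: din - dout])
assert slow_chunks == 50     # 50 size-1 chunks along axis 0 alone
```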

Automatic Chunking
------------------

Expand Down
Binary file added docs/source/images/merge_chunks.png
Binary file added docs/source/images/merge_chunks_false.png
Binary file added docs/source/images/reshape.png
Binary file added docs/source/images/reshape_problem.png
Binary file added docs/source/images/reshape_rechunked.png
