Skip to content

Commit

Permalink
Avoid rechunking in reshape with chunksize=1
Browse files Browse the repository at this point in the history
When the slow-moving (early) axes in `.reshape` all have chunksize 1 (i.e. they
are completely chunked), we can avoid an intermediate rechunk, which could cause memory issues.

```
00 01 | 02 03   # a[0, :, :]
----- | -----
04 05 | 06 07
08 09 | 10 11

=============

12 13 | 14 15   # a[1, :, :]
----- | -----
16 17 | 18 19
20 21 | 22 23

-> (6, 4)

00 01 | 02 03
----- | -----
04 05 | 06 07
08 09 | 10 11
----- | -----
12 13 | 14 15
----- | -----
16 17 | 18 19
20 21 | 22 23
```

xref dask#5544, specifically the examples
given in dask#5544 (comment).
  • Loading branch information
TomAugspurger committed Oct 19, 2020
1 parent e85942f commit 2c84f1d
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 10 deletions.
30 changes: 20 additions & 10 deletions dask/array/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,26 @@ def reshape_rechunk(inshape, outshape, inchunks):
if reduce(mul, inshape[ileft : ii + 1]) != dout:
raise ValueError("Shapes not compatible")

for i in range(ileft + 1, ii + 1): # need single-shape dimensions
result_inchunks[i] = (inshape[i],) # chunks[i] = (4,)

chunk_reduction = reduce(mul, map(len, inchunks[ileft + 1 : ii + 1]))
result_inchunks[ileft] = expand_tuple(inchunks[ileft], chunk_reduction)

prod = reduce(mul, inshape[ileft + 1 : ii + 1]) # 16
result_outchunks[oi] = tuple(
prod * c for c in result_inchunks[ileft]
) # (1, 1, 1, 1) .* 16
# Special case to avoid intermediate rechunking:
# When all the lower axes are completely chunked (chunksize=1) then
# we're simply moving around blocks.
if max(max(inchunks[i]) for i in range(ileft + 1)) == 1:
for i in range(ii + 1):
result_inchunks[i] = inchunks[i]
result_outchunks[oi] = inchunks[i] * np.prod(
list(map(len, inchunks[:i]))
)
else:
for i in range(ileft + 1, ii + 1): # need single-shape dimensions
result_inchunks[i] = (inshape[i],) # chunks[i] = (4,)

chunk_reduction = reduce(mul, map(len, inchunks[ileft + 1 : ii + 1]))
result_inchunks[ileft] = expand_tuple(inchunks[ileft], chunk_reduction)

prod = reduce(mul, inshape[ileft + 1 : ii + 1]) # 16
result_outchunks[oi] = tuple(
prod * c for c in result_inchunks[ileft]
) # (1, 1, 1, 1) .* 16

oi -= 1
ii = ileft - 1
Expand Down
43 changes: 43 additions & 0 deletions dask/array/tests/test_reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,46 @@ def test_reshape_unknown_sizes():
a.reshape((60, -1, -1))
with pytest.raises(ValueError):
A.reshape((60, -1, -1))


@pytest.mark.parametrize(
    "inshape, inchunks, outshape, outchunks",
    [
        # (2, 3, 4) -> (6, 4): leading axis has chunksize 1
        ((2, 3, 4), ((1, 1), (1, 2), (2, 2)), (6, 4), ((1, 2, 1, 2), (2, 2))),
        # (1, 2, 3, 4) -> (6, 4): extra leading axis of length 1
        ((1, 2, 3, 4), ((1,), (1, 1), (1, 2), (2, 2)), (6, 4), ((1, 2, 1, 2), (2, 2))),
        # (2, 2, 3, 4) -> (12, 4): three axes merged into one
        (
            (2, 2, 3, 4),
            ((1, 1), (1, 1), (1, 2), (2, 2)),
            (12, 4),
            ((1, 2, 1, 2, 1, 2, 1, 2), (2, 2)),
        ),
        # (2, 2, 3, 4) -> (4, 3, 4): only the first two axes merged
        (
            (2, 2, 3, 4),
            ((1, 1), (1, 1), (1, 2), (2, 2)),
            (4, 3, 4),
            ((1, 1, 1, 1), (1, 2), (2, 2)),
        ),
        # (2, 2, 3, 4) -> (4, 3, 4): same shapes, but the second axis is a
        # single chunk of size 2 and the last axis is a single chunk of size 4
        ((2, 2, 3, 4), ((1, 1), (2,), (1, 2), (4,)), (4, 3, 4), ((2, 2), (1, 2), (4,))),
    ],
)
def test_reshape_all_chunked_no_merge(inshape, inchunks, outshape, outchunks):
    """Reshape must keep the input chunks as-is for each parametrized case.

    See https://github.com/dask/dask/issues/5544#issuecomment-712280433:
    when the early axes are completely chunked, reshaping only moves blocks
    around, so no intermediate rechunk of the input should be planned.
    """
    source = np.arange(np.prod(inshape)).reshape(inshape)
    darr = da.from_array(source, chunks=inchunks)

    # The chunk planner itself: input chunks are returned unchanged and the
    # output chunks match the expected layout.
    planned_in, planned_out = reshape_rechunk(darr.shape, outshape, inchunks)
    assert planned_in == inchunks
    assert planned_out == outchunks

    # The public reshape API: same output chunks, same values as NumPy.
    reshaped = darr.reshape(outshape)
    assert reshaped.chunks == outchunks
    assert_eq(reshaped, source.reshape(outshape))

0 comments on commit 2c84f1d

Please sign in to comment.