diff --git a/cubed/core/ops.py b/cubed/core/ops.py index 9b76d16ce..74dcb6cf7 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -590,8 +590,11 @@ def wrap(*a, block_id=None, **kw): def rechunk(x, chunks, target_store=None): - if x.chunks == normalize_chunks(chunks, x.shape, dtype=x.dtype): + normalized_chunks = normalize_chunks(chunks, x.shape, dtype=x.dtype) + if x.chunks == normalized_chunks: return x + # normalizing takes care of dict args for chunks + target_chunks = to_chunksize(normalized_chunks) name = gensym() spec = x.spec if target_store is None: @@ -599,7 +602,7 @@ def rechunk(x, chunks, target_store=None): temp_store = new_temp_path(name=f"{name}-intermediate", spec=spec) pipeline = primitive_rechunk( x.zarray_maybe_lazy, - target_chunks=chunks, + target_chunks=target_chunks, allowed_mem=spec.allowed_mem, reserved_mem=spec.reserved_mem, target_store=target_store, diff --git a/cubed/primitive/rechunk.py b/cubed/primitive/rechunk.py index 8ec2873b3..db01e82c3 100644 --- a/cubed/primitive/rechunk.py +++ b/cubed/primitive/rechunk.py @@ -7,11 +7,7 @@ from cubed.runtime.pipeline import spec_to_pipeline from cubed.storage.zarr import lazy_empty from cubed.vendor.rechunker.algorithm import rechunking_plan -from cubed.vendor.rechunker.api import ( - _get_dims_from_zarr_array, - _shape_dict_to_tuple, - _validate_options, -) +from cubed.vendor.rechunker.api import _validate_options from cubed.vendor.rechunker.types import CopySpec @@ -120,16 +116,6 @@ def _setup_array_rechunk( # this is just a pass-through copy target_chunks = source_chunks - if isinstance(target_chunks, dict): - array_dims = _get_dims_from_zarr_array(source_array) - try: - target_chunks = _shape_dict_to_tuple(array_dims, target_chunks) - except KeyError: - raise KeyError( - "You must explicitly specify each dimension size in target_chunks. " - f"Got array_dims {array_dims}, target_chunks {target_chunks}." - ) - # TODO: rewrite to avoid the hard dependency on dask max_mem = cubed.vendor.dask.utils.parse_bytes(max_mem) diff --git a/cubed/tests/test_core.py b/cubed/tests/test_core.py index f54758c0f..4f79883e3 100644 --- a/cubed/tests/test_core.py +++ b/cubed/tests/test_core.py @@ -197,9 +197,10 @@ def test_multiple_ops(spec, executor): ) -def test_rechunk(spec, executor): +@pytest.mark.parametrize("new_chunks", [(1, 2), {0: 1, 1: 2}]) +def test_rechunk(spec, executor, new_chunks): a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 1), spec=spec) - b = a.rechunk((1, 2)) + b = a.rechunk(new_chunks) assert_array_equal( b.compute(executor=executor), np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),