Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(feat): custom reopen with read_elem_as_dask for remote h5ad #1665

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions src/anndata/_io/specs/lazy_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from .registry import _LAZY_REGISTRY, IOSpec

if TYPE_CHECKING:
from collections.abc import Callable, Generator, Mapping, Sequence
from collections.abc import Callable, Iterator, Mapping, Sequence
from typing import Literal, ParamSpec, TypeVar

from ..._core.sparse_dataset import _CSCDataset, _CSRDataset
Expand All @@ -36,7 +36,7 @@
@contextmanager
def maybe_open_h5(
path_or_group: Path | ZarrGroup, elem_name: str
) -> Generator[StorageType, None, None]:
) -> Callable[[], Iterator[StorageType]]:
Copy link
Member

@flying-sheep flying-sheep Sep 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The previous type was correct, the new one isn’t.

The way decorators interact with typing is that you normally type the decorated function (e.g. if it contains yield, the return type is Generator). The decorator than transforms the function from whatever it is to whatever the decorator wants.

I.e.

@contextmanager
def foo(*args: Unpack[Args]) -> Generator[Ret, None, None]: ...

is the same as

def _foo(*args: Unpack[Args]) -> Generator[Ret, None, None]: ...

foo: Callable[Args, AbstractContextManager[Ret]] = contextmanager(_foo)

(I’m not 100% sure I got the “unpack” syntax right, but you know what I mean)

if not isinstance(path_or_group, Path):
yield path_or_group
return
Expand Down Expand Up @@ -67,13 +67,18 @@ def make_dask_chunk(
*,
wrap: Callable[[ArrayStorageType], ArrayStorageType]
| Callable[[H5Group | ZarrGroup], _CSRDataset | _CSCDataset] = lambda g: g,
reopen: None | Callable[[], Iterator[StorageType]] = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so the idea is “reopen is a callable that can be transformed into a context manager using contextlib.contextmanager.

Why not just “reopen is a callable that returns an contextlib.AbstractContextManager[StorageType]?

):
if block_info is None:
msg = "Block info is required"
raise ValueError(msg)
# We need to open the file in each task since `dask` cannot share h5py objects when using `dask.distributed`
# https://github.com/scverse/anndata/issues/1105
with maybe_open_h5(path_or_group, elem_name) as f:
with (
contextmanager(reopen)()
if reopen is not None
else maybe_open_h5(path_or_group, elem_name)
) as f:
mtx = wrap(f)
idx = tuple(
slice(start, stop) for start, stop in block_info[None]["array-location"]
Expand All @@ -91,6 +96,7 @@ def read_sparse_as_dask(
*,
_reader: DaskReader,
chunks: tuple[int, ...] | None = None, # only tuple[int, int] is supported here
reopen: None | Callable[[], Iterator[StorageType]] = None,
) -> DaskArray:
import dask.array as da

Expand Down Expand Up @@ -120,7 +126,7 @@ def read_sparse_as_dask(
)
memory_format = sparse.csc_matrix if is_csc else sparse.csr_matrix
make_chunk = partial(
make_dask_chunk, path_or_group, elem_name, wrap=ad.sparse_dataset
make_dask_chunk, path_or_group, elem_name, wrap=ad.sparse_dataset, reopen=reopen
)
da_mtx = da.map_blocks(
make_chunk,
Expand All @@ -133,7 +139,11 @@ def read_sparse_as_dask(

@_LAZY_REGISTRY.register_read(H5Array, IOSpec("array", "0.2.0"))
def read_h5_array(
elem: H5Array, *, _reader: DaskReader, chunks: tuple[int, ...] | None = None
elem: H5Array,
*,
_reader: DaskReader,
chunks: tuple[int, ...] | None = None,
reopen: None | Callable[[], Iterator[StorageType]] = None,
) -> DaskArray:
import dask.array as da

Expand All @@ -156,7 +166,11 @@ def read_h5_array(

@_LAZY_REGISTRY.register_read(ZarrArray, IOSpec("array", "0.2.0"))
def read_zarr_array(
elem: ZarrArray, *, _reader: DaskReader, chunks: tuple[int, ...] | None = None
elem: ZarrArray,
*,
_reader: DaskReader,
chunks: tuple[int, ...] | None = None,
reopen: None | Callable[[], Iterator[StorageType]] = None,
) -> DaskArray:
chunks: tuple[int, ...] = chunks if chunks is not None else elem.chunks
import dask.array as da
Expand Down
14 changes: 9 additions & 5 deletions src/anndata/_io/specs/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from anndata.compat import DaskArray, _read_attr

if TYPE_CHECKING:
from collections.abc import Callable, Generator, Iterable
from collections.abc import Callable, Generator, Iterable, Iterator
from typing import Any

from anndata._types import (
Expand Down Expand Up @@ -289,6 +289,7 @@ def read_elem(
elem: StorageType,
modifiers: frozenset[str] = frozenset(),
chunks: tuple[int, ...] | None = None,
reopen: None | Callable[[], Iterator[StorageType]] = None,
) -> DaskArray:
"""Read a dask element from a store. See exported function for more details."""

Expand All @@ -299,7 +300,7 @@ def read_elem(
if self.callback is not None:
msg = "Dask reading does not use a callback. Ignoring callback."
warnings.warn(msg, stacklevel=2)
return read_func(elem, chunks=chunks)
return read_func(elem, chunks=chunks, reopen=reopen)


class Writer:
Expand Down Expand Up @@ -379,7 +380,9 @@ def read_elem(elem: StorageType) -> RWAble:


def read_elem_as_dask(
elem: StorageType, chunks: tuple[int, ...] | None = None
elem: StorageType,
chunks: tuple[int, ...] | None = None,
reopen: None | Callable[[], Iterator[StorageType]] = None,
) -> DaskArray:
"""
Read an element from a store lazily.
Expand All @@ -395,12 +398,13 @@ def read_elem_as_dask(
chunks, optional
length `n`, the same `n` as the size of the underlying array.
Note that the minor axis dimension must match the shape for sparse.

reopen, optional
A custom function for re-opening your store in the dask reader.
Returns
-------
DaskArray
"""
return DaskReader(_LAZY_REGISTRY).read_elem(elem, chunks=chunks)
return DaskReader(_LAZY_REGISTRY).read_elem(elem, chunks=chunks, reopen=reopen)


def write_elem(
Expand Down
Loading