Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport PR #1744 on branch 0.11.x ((fix): cache arrays in BaseCompressedSparseDataset) #1778

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/release-notes/1744.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Cache accesses to the `data` and `indices` arrays in {class}`~anndata.abc.CSRDataset` and {class}`~anndata.abc.CSCDataset` {user}`ilan-gold`
33 changes: 25 additions & 8 deletions src/anndata/_core/sparse_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from scipy.sparse._compressed import _cs_matrix

from .._types import GroupStorageType
from ..compat import H5Array
from .index import Index
else:
from scipy.sparse import spmatrix as _cs_matrix
Expand Down Expand Up @@ -380,7 +381,7 @@ def backend(self) -> Literal["zarr", "hdf5"]:
@property
def dtype(self) -> np.dtype:
"""The :class:`numpy.dtype` of the `data` attribute of the sparse matrix."""
return self.group["data"].dtype
return self._data.dtype

@classmethod
def _check_group_format(cls, group):
Expand Down Expand Up @@ -545,16 +546,18 @@ def append(self, sparse_matrix: ss.csr_matrix | ss.csc_matrix | SpArray) -> None
indptr[orig_data_size:] = (
sparse_matrix.indptr[1:].astype(np.int64) + indptr_offset
)
# Clear cached property
if hasattr(self, "indptr"):
del self._indptr

# indices
indices = self.group["indices"]
orig_data_size = indices.shape[0]
indices.resize((orig_data_size + sparse_matrix.indices.shape[0],))
indices[orig_data_size:] = sparse_matrix.indices

# Clear cached property
for attr in ["_indptr", "_indices", "_data"]:
if hasattr(self, attr):
delattr(self, attr)

@cached_property
def _indptr(self) -> np.ndarray:
"""\
Expand All @@ -565,11 +568,25 @@ def _indptr(self) -> np.ndarray:
arr = self.group["indptr"][...]
return arr

@cached_property
def _indices(self) -> H5Array | ZarrArray:
"""\
Cache access to the indices to prevent unnecessary reads of the zarray
"""
return self.group["indices"]

@cached_property
def _data(self) -> H5Array | ZarrArray:
"""\
Cache access to the data to prevent unnecessary reads of the zarray
"""
return self.group["data"]

def _to_backed(self) -> BackedSparseMatrix:
format_class = get_backed_class(self.format)
mtx = format_class(self.shape, dtype=self.dtype)
mtx.data = self.group["data"]
mtx.indices = self.group["indices"]
mtx.data = self._data
mtx.indices = self._indices
mtx.indptr = self._indptr
return mtx

Expand All @@ -578,8 +595,8 @@ def to_memory(self) -> ss.csr_matrix | ss.csc_matrix | SpArray:
self.format, use_sparray_in_io=settings.use_sparse_array_on_read
)
mtx = format_class(self.shape, dtype=self.dtype)
mtx.data = self.group["data"][...]
mtx.indices = self.group["indices"][...]
mtx.data = self._data[...]
mtx.indices = self._indices[...]
mtx.indptr = self._indptr
return mtx

Expand Down
54 changes: 32 additions & 22 deletions src/anndata/_io/specs/lazy_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,24 @@
from contextlib import contextmanager
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, overload

import h5py
import numpy as np
from scipy import sparse

import anndata as ad
from anndata.abc import CSCDataset, CSRDataset

from ..._core.file_backing import filename, get_elem_name
from ...compat import H5Array, H5Group, ZarrArray, ZarrGroup
from .registry import _LAZY_REGISTRY, IOSpec

if TYPE_CHECKING:
from collections.abc import Callable, Generator, Mapping, Sequence
from collections.abc import Generator, Mapping, Sequence
from typing import Literal, ParamSpec, TypeVar

from ..._core.sparse_dataset import _CSCDataset, _CSRDataset
from ..._types import ArrayStorageType, StorageType
from ...compat import DaskArray
from ...compat import DaskArray, H5File, SpArray
from .registry import DaskReader

BlockInfo = Mapping[
Expand All @@ -31,16 +30,25 @@

P = ParamSpec("P")
R = TypeVar("R")
D = TypeVar("D")


@overload
@contextmanager
def maybe_open_h5(
path_or_group: Path | ZarrGroup, elem_name: str
) -> Generator[StorageType, None, None]:
if not isinstance(path_or_group, Path):
yield path_or_group
path_or_other: Path, elem_name: str
) -> Generator[H5File, None, None]: ...
@overload
@contextmanager
def maybe_open_h5(path_or_other: D, elem_name: str) -> Generator[D, None, None]: ...
@contextmanager
def maybe_open_h5(
path_or_other: H5File | D, elem_name: str
) -> Generator[H5File | D, None, None]:
if not isinstance(path_or_other, Path):
yield path_or_other
return
file = h5py.File(path_or_group, "r")
file = h5py.File(path_or_other, "r")
try:
yield file[elem_name]
finally:
Expand All @@ -61,20 +69,17 @@ def compute_chunk_layout_for_axis_shape(


def make_dask_chunk(
path_or_group: Path | ZarrGroup,
path_or_sparse_dataset: Path | D,
elem_name: str,
block_info: BlockInfo | None = None,
*,
wrap: Callable[[ArrayStorageType], ArrayStorageType]
| Callable[[H5Group | ZarrGroup], _CSRDataset | _CSCDataset] = lambda g: g,
):
) -> sparse.csr_matrix | sparse.csc_matrix | SpArray:
if block_info is None:
msg = "Block info is required"
raise ValueError(msg)
# We need to open the file in each task since `dask` cannot share h5py objects when using `dask.distributed`
# https://github.com/scverse/anndata/issues/1105
with maybe_open_h5(path_or_group, elem_name) as f:
mtx = wrap(f)
with maybe_open_h5(path_or_sparse_dataset, elem_name) as f:
mtx = ad.io.sparse_dataset(f) if isinstance(f, H5Group) else f
idx = tuple(
slice(start, stop) for start, stop in block_info[None]["array-location"]
)
Expand All @@ -94,10 +99,17 @@ def read_sparse_as_dask(
) -> DaskArray:
import dask.array as da

path_or_group = Path(filename(elem)) if isinstance(elem, H5Group) else elem
path_or_sparse_dataset = (
Path(filename(elem))
if isinstance(elem, H5Group)
else ad.io.sparse_dataset(elem)
)
elem_name = get_elem_name(elem)
shape: tuple[int, int] = tuple(elem.attrs["shape"])
dtype = elem["data"].dtype
if isinstance(path_or_sparse_dataset, CSRDataset | CSCDataset):
dtype = path_or_sparse_dataset.dtype
else:
dtype = elem["data"].dtype
is_csc: bool = elem.attrs["encoding-type"] == "csc_matrix"

stride: int = _DEFAULT_STRIDE
Expand All @@ -123,9 +135,7 @@ def read_sparse_as_dask(
(chunks_minor, chunks_major) if is_csc else (chunks_major, chunks_minor)
)
memory_format = sparse.csc_matrix if is_csc else sparse.csr_matrix
make_chunk = partial(
make_dask_chunk, path_or_group, elem_name, wrap=ad.io.sparse_dataset
)
make_chunk = partial(make_dask_chunk, path_or_sparse_dataset, elem_name)
da_mtx = da.map_blocks(
make_chunk,
dtype=dtype,
Expand Down
72 changes: 58 additions & 14 deletions tests/test_backed_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
import anndata as ad
from anndata._core.anndata import AnnData
from anndata._core.sparse_dataset import sparse_dataset
from anndata.compat import CAN_USE_SPARSE_ARRAY, SpArray
from anndata._io.specs.registry import read_elem_as_dask
from anndata.compat import CAN_USE_SPARSE_ARRAY, DaskArray, SpArray
from anndata.experimental import read_dispatched
from anndata.tests.helpers import AccessTrackingStore, assert_equal, subset_func

Expand All @@ -26,6 +27,9 @@
from numpy.typing import ArrayLike, NDArray
from pytest_mock import MockerFixture

from anndata.abc import CSCDataset, CSRDataset
from anndata.compat import ZarrGroup

Idx = slice | int | NDArray[np.integer] | NDArray[np.bool_]


Expand Down Expand Up @@ -281,6 +285,25 @@ def test_dataset_append_memory(
assert_equal(fromdisk, frommem)


def test_append_array_cache_bust(tmp_path: Path, diskfmt: Literal["h5ad", "zarr"]):
path = tmp_path / f"test.{diskfmt.replace('ad', '')}"
a = sparse.random(100, 100, format="csr")
if diskfmt == "zarr":
f = zarr.open_group(path, "a")
else:
f = h5py.File(path, "a")
ad.io.write_elem(f, "mtx", a)
ad.io.write_elem(f, "mtx_2", a)
diskmtx = sparse_dataset(f["mtx"])
old_array_shapes = {}
array_names = ["indptr", "indices", "data"]
for name in array_names:
old_array_shapes[name] = getattr(diskmtx, f"_{name}").shape
diskmtx.append(sparse_dataset(f["mtx_2"]))
for name in array_names:
assert old_array_shapes[name] != getattr(diskmtx, f"_{name}").shape


@pytest.mark.parametrize("sparse_format", [sparse.csr_matrix, sparse.csc_matrix])
@pytest.mark.parametrize(
("subset_func", "subset_func2"),
Expand Down Expand Up @@ -354,16 +377,18 @@ def test_dataset_append_disk(


@pytest.mark.parametrize("sparse_format", [sparse.csr_matrix, sparse.csc_matrix])
def test_indptr_cache(
def test_lazy_array_cache(
tmp_path: Path,
sparse_format: Callable[[ArrayLike], sparse.spmatrix],
):
elems = {"indptr", "indices", "data"}
path = tmp_path / "test.zarr"
a = sparse_format(sparse.random(10, 10))
f = zarr.open_group(path, "a")
ad.io.write_elem(f, "X", a)
store = AccessTrackingStore(path)
store.initialize_key_trackers(["X/indptr"])
for elem in elems:
store.initialize_key_trackers([f"X/{elem}"])
f = zarr.open_group(store, "a")
a_disk = sparse_dataset(f["X"])
a_disk[:1]
Expand All @@ -372,6 +397,14 @@ def test_indptr_cache(
a_disk[8:9]
# one each for .zarray and actual access
assert store.get_access_count("X/indptr") == 2
for elem_not_indptr in elems - {"indptr"}:
assert (
sum(
".zarray" in key_accessed
for key_accessed in store.get_accessed_keys(f"X/{elem_not_indptr}")
)
== 1
)


Kind = Literal["slice", "int", "array", "mask"]
Expand Down Expand Up @@ -421,27 +454,38 @@ def width_idx_kinds(
(
[0],
slice(None, None),
["X/data/.zarray", "X/data/.zarray", "X/data/0"],
["X/data/.zarray", "X/data/0"],
),
(
[0],
slice(None, 3),
["X/data/.zarray", "X/data/.zarray", "X/data/0"],
["X/data/.zarray", "X/data/0"],
),
(
[3, 4, 5],
slice(None, None),
["X/data/.zarray", "X/data/.zarray", "X/data/3", "X/data/4", "X/data/5"],
["X/data/.zarray", "X/data/3", "X/data/4", "X/data/5"],
),
l=10,
),
)
@pytest.mark.parametrize(
"open_func",
[
sparse_dataset,
lambda x: read_elem_as_dask(
x, chunks=(1, -1) if x.attrs["encoding-type"] == "csr_matrix" else (-1, 1)
),
],
ids=["sparse_dataset", "read_elem_as_dask"],
)
def test_data_access(
tmp_path: Path,
sparse_format: Callable[[ArrayLike], sparse.spmatrix],
idx_maj: Idx,
idx_min: Idx,
exp: Sequence[str],
open_func: Callable[[ZarrGroup], CSRDataset | CSCDataset | DaskArray],
):
path = tmp_path / "test.zarr"
a = sparse_format(np.eye(10, 10))
Expand All @@ -454,19 +498,19 @@ def test_data_access(
store = AccessTrackingStore(path)
store.initialize_key_trackers(["X/data"])
f = zarr.open_group(store)
a_disk = sparse_dataset(f["X"])

# Do the slicing with idx
store.reset_key_trackers()
if a_disk.format == "csr":
a_disk[idx_maj, idx_min]
a_disk = AnnData(X=open_func(f["X"]))
if a.format == "csr":
subset = a_disk[idx_maj, idx_min]
else:
a_disk[idx_min, idx_maj]
subset = a_disk[idx_min, idx_maj]
if isinstance(subset.X, DaskArray):
subset.X.compute(scheduler="single-threaded")

assert store.get_access_count("X/data") == len(exp), store.get_accessed_keys(
"X/data"
)
assert store.get_accessed_keys("X/data") == exp
# dask access order is not guaranteed so need to sort
assert sorted(store.get_accessed_keys("X/data")) == sorted(exp)


@pytest.mark.parametrize(
Expand Down
4 changes: 3 additions & 1 deletion tests/test_concatenate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1044,7 +1044,9 @@ def gen_list(n):


def gen_sparse(n):
return sparse.random(np.random.randint(1, 100), np.random.randint(1, 100))
return sparse.random(
np.random.randint(1, 100), np.random.randint(1, 100), format="csr"
)


def gen_something(n):
Expand Down