diff --git a/pyproject.toml b/pyproject.toml index 2af90bb..cc34118 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,8 +24,8 @@ dependencies = [ "pandas", "pyarrow", "networkx", - # for debug logging (referenced from the issue template) "session-info", + "zarr", ] [project.optional-dependencies] @@ -51,6 +51,7 @@ test = [ "pytest", "coverage", "scanpy", + "joblib", ] [tool.coverage.run] diff --git a/src/treedata/__init__.py b/src/treedata/__init__.py index 25f3808..8c8c6d5 100644 --- a/src/treedata/__init__.py +++ b/src/treedata/__init__.py @@ -1,6 +1,7 @@ from importlib.metadata import version from ._core.merge import concat +from ._core.read import read_h5ad, read_zarr from ._core.treedata import TreeData __version__ = version("treedata") diff --git a/src/treedata/_core/aligned_mapping.py b/src/treedata/_core/aligned_mapping.py index 3efe2c6..ca01718 100755 --- a/src/treedata/_core/aligned_mapping.py +++ b/src/treedata/_core/aligned_mapping.py @@ -185,6 +185,8 @@ def __init__( self._dimnames = ("obs", "var") self.subset_idx = subset_idx self._axis = parent_mapping._axis + self._tree_to_leaf = parent_mapping._tree_to_leaf + self._leaf_to_tree = parent_mapping._leaf_to_tree def __getitem__(self, key: str) -> nx.DiGraph: leaves = self.parent_mapping._tree_to_leaf[key] diff --git a/src/treedata/_core/read.py b/src/treedata/_core/read.py new file mode 100755 index 0000000..ad756b6 --- /dev/null +++ b/src/treedata/_core/read.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +from collections.abc import MutableMapping, Sequence +from pathlib import Path +from typing import ( + Literal, +) + +import anndata as ad +import zarr +from scipy import sparse + +from treedata._core.aligned_mapping import AxisTrees +from treedata._core.treedata import TreeData +from treedata._utils import dict_to_digraph + + +def _tdata_from_adata(tdata) -> TreeData: + """Create a TreeData object parsing attribute from AnnData uns field.""" + tdata.__class__ = TreeData + if "treedata_attrs" in tdata.uns.keys(): + treedata_attrs = tdata.uns["treedata_attrs"] + tdata._tree_label = treedata_attrs["label"] if "label" in treedata_attrs.keys() else None + tdata._allow_overlap = bool(treedata_attrs["allow_overlap"]) + tdata._obst = AxisTrees(tdata, 0, vals={k: dict_to_digraph(v) for k, v in treedata_attrs["obst"].items()}) + tdata._vart = AxisTrees(tdata, 1, vals={k: dict_to_digraph(v) for k, v in treedata_attrs["vart"].items()}) + del tdata.uns["treedata_attrs"] + else: + tdata._tree_label = None + tdata._allow_overlap = False + tdata._obst = AxisTrees(tdata, 0) + tdata._vart = AxisTrees(tdata, 1) + return tdata + + +def read_h5ad( + filename: str | Path = None, + backed: Literal["r", "r+"] | bool | None = None, + *, + as_sparse: Sequence[str] = (), + as_sparse_fmt: type[sparse.spmatrix] = sparse.csr_matrix, + chunk_size: int = 6000, +) -> TreeData: + """Read `.h5ad`-formatted hdf5 file. + + Parameters + ---------- + filename + File name of data file. + backed + If `'r'`, load :class:`~anndata.TreeData` in `backed` mode + instead of fully loading it into memory (`memory` mode). + If you want to modify backed attributes of the TreeData object, + you need to choose `'r+'`. + as_sparse + If an array was saved as dense, passing its name here will read it as + a sparse_matrix, by chunk of size `chunk_size`. + as_sparse_fmt + Sparse format class to read elements from `as_sparse` in as. + chunk_size + Used only when loading sparse dataset that is stored as dense. + Loading iterates through chunks of the dataset of this row size + until it reads the whole dataset. + Higher size means higher memory consumption and higher (to a point) + loading speed. + """ + adata = ad.read_h5ad( + filename, + backed=backed, + as_sparse=as_sparse, + as_sparse_fmt=as_sparse_fmt, + chunk_size=chunk_size, + ) + return _tdata_from_adata(adata) + + +def read_zarr(store: str | Path | MutableMapping | zarr.Group) -> TreeData: + """Read from a hierarchical Zarr array store. + + Parameters + ---------- + store + The filename, a :class:`~typing.MutableMapping`, or a Zarr storage class. + """ + adata = ad.read_zarr(store) + return _tdata_from_adata(adata) diff --git a/src/treedata/_core/treedata.py b/src/treedata/_core/treedata.py index 758c66d..94a9869 100755 --- a/src/treedata/_core/treedata.py +++ b/src/treedata/_core/treedata.py @@ -1,6 +1,7 @@ from __future__ import annotations -from collections.abc import Iterable, Mapping, Sequence +import warnings +from collections.abc import Iterable, Mapping, MutableMapping, Sequence from typing import ( TYPE_CHECKING, Any, @@ -14,6 +15,8 @@ from anndata._core.index import Index, Index1D from scipy import sparse +from treedata._utils import digraph_to_dict + from .aligned_mapping import ( AxisTrees, ) @@ -64,7 +67,7 @@ class TreeData(ad.AnnData): filemode Open mode of backing file. See :class:`h5py.File`. asview - Initialize as view. `X` has to be an AnnData object. + Initialize as view. `X` has to be an TreeData object. label Columns in `.obs` and `.var` to place tree key in. Default is "tree". If it's None, no column is added. @@ -224,6 +227,11 @@ def label(self) -> str | None: """Column in `.obs` and .`obs` with tree keys""" return self._tree_label + @property + def is_view(self) -> bool: + """`True` if object is view of another TreeData object, `False` otherwise.""" + return self._is_view + @obst.setter def obst(self, value): obst = AxisTrees(self, 0, vals=dict(value)) @@ -265,14 +273,122 @@ def to_adata(self) -> ad.AnnData: """Convert this TreeData object to an AnnData object.""" return ad.AnnData(self) - def copy(self) -> TreeData: - """Full copy of the object.""" - adata = super().copy() - treedata_copy = TreeData( + def _treedata_attrs(self) -> dict: + """Dictionary of TreeData attributes""" + return { + "obst": {k: digraph_to_dict(v) for k, v in self.obst.items()}, + "vart": {k: digraph_to_dict(v) for k, v in self.vart.items()}, + "label": self.label, + "allow_overlap": self.allow_overlap, + } + + def copy(self, filename: PathLike | None = None) -> TreeData: + """Full copy, optionally on disk""" + adata = super().copy(filename=filename) + if not self.isbacked: + treedata_copy = TreeData( + adata, + obst=self.obst.copy(), + vart=self.vart.copy(), + label=self.label, + allow_overlap=self.allow_overlap, + ) + else: + from .read import read_h5ad + + if filename is None: + raise ValueError( + "To copy an TreeData object in backed mode, " + "pass a filename: `.copy(filename='myfilename.h5ad')`. " + "To load the object into memory, use `.to_memory()`." + ) + mode = self.file._filemode + adata.uns["treedata_attrs"] = self._treedata_attrs() + adata.write_h5ad(filename) + treedata_copy = read_h5ad(filename, backed=mode) + return treedata_copy + + def transpose(self) -> TreeData: + """Transpose whole object + + Data matrix is transposed, observations and variables are interchanged. + Ignores `.raw`. + """ + adata = super().transpose() + treedata_transpose = TreeData( + adata, + obst=self.vart.copy(), + vart=self.obst.copy(), + label=self.label, + allow_overlap=self.allow_overlap, + ) + return treedata_transpose + + T = property(transpose) + + def write_h5ad( + self, + filename: PathLike | None = None, + compression: Literal["gzip", "lzf"] | None = None, + compression_opts: int | Any = None, + as_dense: Sequence[str] = (), + ): + """Write `.h5ad`-formatted hdf5 file. + + Parameters + ---------- + filename + Filename of data file. Defaults to backing file. + compression + [`lzf`, `gzip`], see the h5py :ref:`dataset_compression`. + compression_opts + [`lzf`, `gzip`], see the h5py :ref:`dataset_compression`. + as_dense + Sparse arrays in TreeData object to write as dense. Currently only + supports `X` and `raw/X`. + """ + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + self.uns["treedata_attrs"] = self._treedata_attrs() + super().write_h5ad( + filename=filename, compression=compression, compression_opts=compression_opts, as_dense=as_dense + ) + self.uns.pop("treedata_attrs") + + write = write_h5ad # a shortcut and backwards compat + + def write_zarr( + self, + store: MutableMapping | PathLike, + chunks: bool | int | tuple[int, ...] | None = None, + ): + """Write a hierarchical Zarr array store. + + Parameters + ---------- + store + The filename, a :class:`~typing.MutableMapping`, or a Zarr storage class. + chunks + Chunk shape. + """ + adata = self.to_adata() + adata.uns["treedata_attrs"] = self._treedata_attrs() + adata.write_zarr(store=store, chunks=chunks) + + def to_memory(self, copy=False) -> TreeData: + """Return a new AnnData object with all backed arrays loaded into memory. + + Params + ------ + copy: + Whether the arrays that are already in-memory should be copied. + """ + adata = super().to_memory(copy) + tdata = TreeData( adata, obst=self.obst.copy(), vart=self.vart.copy(), label=self.label, allow_overlap=self.allow_overlap, ) - return treedata_copy + return tdata diff --git a/src/treedata/_utils.py b/src/treedata/_utils.py index 9f360fd..eaea40c 100755 --- a/src/treedata/_utils.py +++ b/src/treedata/_utils.py @@ -34,3 +34,28 @@ def combine_trees(subsets: list[nx.DiGraph]) -> nx.DiGraph: # The combined_tree now contains all nodes and edges from the subsets return combined_tree + + +def digraph_to_dict(G: nx.DiGraph) -> dict: + """Convert a networkx.DiGraph to a dictionary.""" + G = nx.DiGraph(G) + edge_dict = nx.to_dict_of_dicts(G) + # Get node data + node_dict = {node: G.nodes[node] for node in G.nodes()} + # Combine edge and node data in one dictionary + graph_dict = {"edges": edge_dict, "nodes": node_dict} + + return graph_dict + + +def dict_to_digraph(graph_dict: dict) -> nx.DiGraph: + """Convert a dictionary to a networkx.DiGraph.""" + G = nx.DiGraph() + # Add nodes and their attributes + for node, attrs in graph_dict["nodes"].items(): + G.add_node(node, **attrs) + # Add edges and their attributes + for source, targets in graph_dict["edges"].items(): + for target, attrs in targets.items(): + G.add_edge(source, target, **attrs) + return G diff --git a/tests/test_base.py b/tests/test_base.py index 582cc85..e927de9 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -25,8 +25,7 @@ def tree(): def check_graph_equality(g1, g2): - assert g1.nodes == g2.nodes - assert g1.edges == g2.edges + assert nx.is_isomorphic(g1, g2, node_match=lambda n1, n2: n1 == n2, edge_match=lambda e1, e2: e1 == e2) def test_creation(X, adata, tree): @@ -173,3 +172,12 @@ def test_copy(adata, tree): assert np.array_equal(treedata.X, treedata_copy.X) assert treedata.obst["tree"].nodes == treedata_copy.obst["tree"].nodes assert treedata.obst["tree"].edges == treedata_copy.obst["tree"].edges + + +def test_transpose(adata, tree): + treedata = td.TreeData(adata, obst={"tree": tree}) + treedata_transpose = treedata.transpose() + assert np.array_equal(treedata.X.T, treedata_transpose.X) + assert treedata.obst["tree"].nodes == treedata_transpose.vart["tree"].nodes + assert treedata_transpose.obst_keys() == [] + assert np.array_equal(treedata.obs_names, treedata.T.obs_names) diff --git a/tests/test_readwrite.py b/tests/test_readwrite.py new file mode 100755 index 0000000..f53f927 --- /dev/null +++ b/tests/test_readwrite.py @@ -0,0 +1,113 @@ +import joblib +import networkx as nx +import numpy as np +import pytest + +import treedata as td + + +@pytest.fixture +def X(): + yield np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + + +@pytest.fixture +def tree(): + tree = nx.DiGraph() + tree.add_edges_from([("root", "0"), ("root", "1")]) + tree["root"]["0"]["length"] = 1 + tree.nodes["root"]["depth"] = 0 + yield tree + + +@pytest.fixture +def tdata(X, tree): + yield td.TreeData(X, obst={"tree": tree}, vart={"tree": tree}, label=None, allow_overlap=False) + + +def check_graph_equality(g1, g2): + assert nx.is_isomorphic(g1, g2, node_match=lambda n1, n2: n1 == n2, edge_match=lambda e1, e2: e1 == e2) + + +def test_h5ad_readwrite(tdata, tmp_path): + # not backed + file_path = tmp_path / "test.h5ad" + tdata.write_h5ad(file_path) + tdata2 = td.read_h5ad(file_path) + assert np.array_equal(tdata2.X, tdata.X) + check_graph_equality(tdata2.obst["tree"], tdata.obst["tree"]) + check_graph_equality(tdata2.vart["tree"], tdata.vart["tree"]) + assert tdata2.label is None + assert tdata2.allow_overlap is False + # backed + tdata2 = td.read_h5ad(file_path, backed="r") + assert np.array_equal(tdata2.X, tdata.X) + check_graph_equality(tdata2.obst["tree"], tdata.obst["tree"]) + check_graph_equality(tdata2.vart["tree"], tdata.vart["tree"]) + assert tdata2.label is None + assert tdata2.allow_overlap is False + assert tdata2.isbacked + assert tdata2.file.is_open + assert tdata2.filename == file_path + + +def test_zarr_readwrite(tdata, tmp_path): + tdata.write_zarr(tmp_path / "test.zarr") + tdata2 = td.read_zarr(tmp_path / "test.zarr") + assert np.array_equal(tdata2.X, tdata.X) + check_graph_equality(tdata2.obst["tree"], tdata.obst["tree"]) + check_graph_equality(tdata2.vart["tree"], tdata.vart["tree"]) + assert tdata2.label is None + assert tdata2.allow_overlap is False + + +def test_read_anndata(tdata, tmp_path): + adata = tdata.to_adata() + file_path = tmp_path / "test.h5ad" + adata.write_h5ad(file_path) + tdata = td.read_h5ad(file_path) + assert np.array_equal(tdata.X, adata.X) + assert tdata.label is None + assert tdata.allow_overlap is False + assert tdata.obst_keys() == [] + + +def test_h5ad_backing(tdata, tree, tmp_path): + tdata_copy = tdata.copy() + assert not tdata.isbacked + backing_h5ad = tmp_path / "test.h5ad" + tdata.filename = backing_h5ad + # backing mode + tdata.write() + assert not tdata.file.is_open + assert tdata.isbacked + # view of backed object + tdata_subset = tdata[:, 0] + subset_hash = joblib.hash(tdata_subset) + assert tdata_subset.is_view + assert tdata_subset.isbacked + assert tdata_subset.shape == (3, 1) + check_graph_equality(tdata_subset.obst["tree"], tdata.obst["tree"]) + assert np.array_equal(tdata_subset.X, tdata_copy.X[:, 0].reshape(-1, 1)) + # cannot set view in backing mode... + with pytest.warns(UserWarning): + with pytest.raises(ValueError): + tdata_subset.obs["foo"] = range(3) + # with pytest.warns(UserWarning): + # with pytest.raises(ValueError): + # tdata_subset.obst["foo"] = tree + assert subset_hash == joblib.hash(tdata_subset) + assert tdata_subset.is_view + # copy + tdata_subset = tdata_subset.copy(tmp_path / "test.subset.h5ad") + assert not tdata_subset.is_view + tdata_subset.obs["foo"] = range(3) + assert not tdata_subset.is_view + assert tdata_subset.isbacked + assert tdata_subset.obs["foo"].tolist() == list(range(3)) + tdata_subset.write() + # move to memory + tdata_subset = tdata_subset.to_memory() + assert not tdata_subset.is_view + assert not tdata_subset.isbacked + check_graph_equality(tdata_subset.obst["tree"], tdata.obst["tree"]) diff --git a/tests/test_views.py b/tests/test_views.py index 2943e9e..eee288f 100755 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -34,16 +34,6 @@ def test_views(tdata): assert tdata_subset.obs["test"].tolist() == list(range(2)) -# this test should pass once anndata bug is fixed -# See https://github.com/scverse/anndata/issues/1382 -@pytest.mark.xfail -def test_views_creation(tdata): - tdata_view = td.TreeData(tdata, asview=True) - assert tdata_view.is_view - with pytest.raises(ValueError): - _ = td.TreeData(np.zeros(shape=(3, 3)), asview=False) - - def test_views_subset_tree(tdata): expected_edges = [ ("0", "1"), @@ -101,6 +91,15 @@ def test_views_set(tdata): assert not tdata_subset.is_view assert list(tdata_subset.obst.keys()) == ["tree", "new_tree"] assert list(tdata_subset.obst["new_tree"].edges) == [("0", "8")] + # good assignment no overlap + tree = nx.DiGraph([("root", "0"), ("root", "1")]) + good_tree = nx.DiGraph([("root", "2"), ("root", "3")]) + tdata = td.TreeData(X=np.zeros((8, 8)), allow_overlap=False, obst={"tree": tree}) + print(tdata.obs_names) + tdata_subset = tdata[:4, :] + print(tdata_subset.obs_names) + with pytest.warns(UserWarning): + tdata_subset.obst["tree"] = good_tree def test_views_del(tdata):