Skip to content

Commit

Permalink
Merge pull request #9 from YosefLab/inherited-methods
Browse files Browse the repository at this point in the history
readwrite
  • Loading branch information
colganwi authored May 9, 2024
2 parents 4210047 + 41e0b75 commit c9fd602
Show file tree
Hide file tree
Showing 9 changed files with 371 additions and 20 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ dependencies = [
"pandas",
"pyarrow",
"networkx",
# for debug logging (referenced from the issue template)
"session-info",
"zarr",
]

[project.optional-dependencies]
Expand All @@ -51,6 +51,7 @@ test = [
"pytest",
"coverage",
"scanpy",
"joblib",
]

[tool.coverage.run]
Expand Down
1 change: 1 addition & 0 deletions src/treedata/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from importlib.metadata import version

from ._core.merge import concat
from ._core.read import read_h5ad, read_zarr
from ._core.treedata import TreeData

__version__ = version("treedata")
2 changes: 2 additions & 0 deletions src/treedata/_core/aligned_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ def __init__(
self._dimnames = ("obs", "var")
self.subset_idx = subset_idx
self._axis = parent_mapping._axis
self._tree_to_leaf = parent_mapping._tree_to_leaf
self._leaf_to_tree = parent_mapping._leaf_to_tree

def __getitem__(self, key: str) -> nx.DiGraph:
leaves = self.parent_mapping._tree_to_leaf[key]
Expand Down
86 changes: 86 additions & 0 deletions src/treedata/_core/read.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from __future__ import annotations

from collections.abc import MutableMapping, Sequence
from pathlib import Path
from typing import (
Literal,
)

import anndata as ad
import zarr
from scipy import sparse

from treedata._core.aligned_mapping import AxisTrees
from treedata._core.treedata import TreeData
from treedata._utils import dict_to_digraph


def _tdata_from_adata(tdata) -> TreeData:
    """Turn an AnnData object into a TreeData object in place.

    The object's class is reassigned to ``TreeData``. If ``tdata.uns`` holds a
    ``"treedata_attrs"`` entry, the tree attributes are restored from it and the
    entry is removed from ``uns``; otherwise empty defaults are assigned.

    Parameters
    ----------
    tdata
        Object to convert. It is modified and returned, not copied.

    Returns
    -------
    The same object, now an instance of ``TreeData``.
    """
    tdata.__class__ = TreeData

    if "treedata_attrs" not in tdata.uns.keys():
        # Nothing was stored: no label, no overlap, empty tree mappings.
        tdata._tree_label = None
        tdata._allow_overlap = False
        tdata._obst = AxisTrees(tdata, 0)
        tdata._vart = AxisTrees(tdata, 1)
        return tdata

    stored = tdata.uns["treedata_attrs"]
    # "label" is optional in the stored attributes; a missing key means no label.
    tdata._tree_label = stored.get("label")
    tdata._allow_overlap = bool(stored["allow_overlap"])

    # Trees are stored as plain dictionaries; rebuild one graph per key.
    obs_trees = {key: dict_to_digraph(tree) for key, tree in stored["obst"].items()}
    tdata._obst = AxisTrees(tdata, 0, vals=obs_trees)
    var_trees = {key: dict_to_digraph(tree) for key, tree in stored["vart"].items()}
    tdata._vart = AxisTrees(tdata, 1, vals=var_trees)

    # The serialized form is no longer needed once the attributes are restored.
    del tdata.uns["treedata_attrs"]
    return tdata


def read_h5ad(
    filename: str | Path | None = None,
    backed: Literal["r", "r+"] | bool | None = None,
    *,
    as_sparse: Sequence[str] = (),
    as_sparse_fmt: type[sparse.spmatrix] = sparse.csr_matrix,
    chunk_size: int = 6000,
) -> TreeData:
    """Read `.h5ad`-formatted hdf5 file.

    The file is read with :func:`anndata.read_h5ad` and the result is converted
    to a TreeData object by `_tdata_from_adata`.

    Parameters
    ----------
    filename
        File name of data file.
    backed
        If `'r'`, load :class:`~treedata.TreeData` in `backed` mode
        instead of fully loading it into memory (`memory` mode).
        If you want to modify backed attributes of the TreeData object,
        you need to choose `'r+'`.
    as_sparse
        If an array was saved as dense, passing its name here will read it as
        a sparse_matrix, by chunk of size `chunk_size`.
    as_sparse_fmt
        Sparse format class to read elements from `as_sparse` in as.
    chunk_size
        Used only when loading sparse dataset that is stored as dense.
        Loading iterates through chunks of the dataset of this row size
        until it reads the whole dataset.
        Higher size means higher memory consumption and higher (to a point)
        loading speed.

    Returns
    -------
    The TreeData object built from the file contents.
    """
    adata = ad.read_h5ad(
        filename,
        backed=backed,
        as_sparse=as_sparse,
        as_sparse_fmt=as_sparse_fmt,
        chunk_size=chunk_size,
    )
    return _tdata_from_adata(adata)


def read_zarr(store: str | Path | MutableMapping | zarr.Group) -> TreeData:
    """Read from a hierarchical Zarr array store.

    The store is read with :func:`anndata.read_zarr` and the result is converted
    to a TreeData object by `_tdata_from_adata`.

    Parameters
    ----------
    store
        The filename, a :class:`~typing.MutableMapping`, or a Zarr storage class.

    Returns
    -------
    The TreeData object built from the store contents.
    """
    return _tdata_from_adata(ad.read_zarr(store))
130 changes: 123 additions & 7 deletions src/treedata/_core/treedata.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from collections.abc import Iterable, Mapping, Sequence
import warnings
from collections.abc import Iterable, Mapping, MutableMapping, Sequence
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -14,6 +15,8 @@
from anndata._core.index import Index, Index1D
from scipy import sparse

from treedata._utils import digraph_to_dict

from .aligned_mapping import (
AxisTrees,
)
Expand Down Expand Up @@ -64,7 +67,7 @@ class TreeData(ad.AnnData):
filemode
Open mode of backing file. See :class:`h5py.File`.
asview
Initialize as view. `X` has to be an AnnData object.
Initialize as view. `X` has to be an TreeData object.
label
Columns in `.obs` and `.var` to place tree key in. Default is "tree".
If it's None, no column is added.
Expand Down Expand Up @@ -224,6 +227,11 @@ def label(self) -> str | None:
"""Column in `.obs` and .`obs` with tree keys"""
return self._tree_label

@property
def is_view(self) -> bool:
    """Whether this object is a view of another TreeData object.

    Returns the value of the private `_is_view` attribute.
    """
    view_flag = self._is_view
    return view_flag

@obst.setter
def obst(self, value):
obst = AxisTrees(self, 0, vals=dict(value))
Expand Down Expand Up @@ -265,14 +273,122 @@ def to_adata(self) -> ad.AnnData:
"""Convert this TreeData object to an AnnData object."""
return ad.AnnData(self)

def copy(self) -> TreeData:
"""Full copy of the object."""
adata = super().copy()
treedata_copy = TreeData(
def _treedata_attrs(self) -> dict:
    """Collect the TreeData-specific state as a plain dictionary.

    Returns
    -------
    A dictionary with the keys ``"obst"`` and ``"vart"`` (each tree converted
    with `digraph_to_dict`, keyed as in the corresponding mapping), ``"label"``
    and ``"allow_overlap"``.
    """
    attrs = {}
    attrs["obst"] = {key: digraph_to_dict(tree) for key, tree in self.obst.items()}
    attrs["vart"] = {key: digraph_to_dict(tree) for key, tree in self.vart.items()}
    attrs["label"] = self.label
    attrs["allow_overlap"] = self.allow_overlap
    return attrs

def copy(self, filename: PathLike | None = None) -> TreeData:
    """Return a full copy of the object, optionally on disk.

    Parameters
    ----------
    filename
        Forwarded to the parent `copy`. Required when the object is backed:
        the copy is written to this `.h5ad` file and read back from it.

    Returns
    -------
    A new TreeData object. For in-memory objects the trees, label and
    `allow_overlap` setting are carried over directly; for backed objects they
    are stored in the written file and restored by `read_h5ad`, which is called
    with the same file mode as this object's backing file.

    Raises
    ------
    ValueError
        If the object is backed and no `filename` is given.
    """
    if self.isbacked and filename is None:
        # Validate before delegating so the caller gets this TreeData-specific
        # message rather than whatever the parent does with a missing filename.
        raise ValueError(
            "To copy an TreeData object in backed mode, "
            "pass a filename: `.copy(filename='myfilename.h5ad')`. "
            "To load the object into memory, use `.to_memory()`."
        )

    adata = super().copy(filename=filename)
    if not self.isbacked:
        return TreeData(
            adata,
            obst=self.obst.copy(),
            vart=self.vart.copy(),
            label=self.label,
            allow_overlap=self.allow_overlap,
        )

    # Imported here rather than at module level, as in the original code
    # (presumably to avoid a circular import with `.read` -- TODO confirm).
    from .read import read_h5ad

    mode = self.file._filemode
    # Store the tree attributes in the written file so `read_h5ad` can restore them.
    adata.uns["treedata_attrs"] = self._treedata_attrs()
    adata.write_h5ad(filename)
    return read_h5ad(filename, backed=mode)

def transpose(self) -> TreeData:
    """Transpose whole object.

    Data matrix is transposed, observations and variables are interchanged.
    The tree mappings are swapped as well: copies of `vart` become `obst` and
    copies of `obst` become `vart`. Ignores `.raw`.

    Returns
    -------
    A new TreeData object with the same label and `allow_overlap` setting.
    """
    flipped = super().transpose()
    return TreeData(
        flipped,
        obst=self.vart.copy(),
        vart=self.obst.copy(),
        label=self.label,
        allow_overlap=self.allow_overlap,
    )

T = property(transpose)

def write_h5ad(
    self,
    filename: PathLike | None = None,
    compression: Literal["gzip", "lzf"] | None = None,
    compression_opts: int | Any = None,
    as_dense: Sequence[str] = (),
):
    """Write `.h5ad`-formatted hdf5 file.

    The tree attributes are placed in ``uns["treedata_attrs"]`` for the
    duration of the write and removed again afterwards, whether or not the
    write succeeds. All warnings raised during the write are suppressed.

    Parameters
    ----------
    filename
        Filename of data file. Defaults to backing file.
    compression
        [`lzf`, `gzip`], see the h5py :ref:`dataset_compression`.
    compression_opts
        [`lzf`, `gzip`], see the h5py :ref:`dataset_compression`.
    as_dense
        Sparse arrays in TreeData object to write as dense. Currently only
        supports `X` and `raw/X`.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        self.uns["treedata_attrs"] = self._treedata_attrs()
        try:
            super().write_h5ad(
                filename=filename, compression=compression, compression_opts=compression_opts, as_dense=as_dense
            )
        finally:
            # Always drop the temporary entry so a failed write does not leave
            # serialized trees behind in `uns`.
            self.uns.pop("treedata_attrs", None)

write = write_h5ad  # a shortcut and backwards compat

def write_zarr(
    self,
    store: MutableMapping | PathLike,
    chunks: bool | int | tuple[int, ...] | None = None,
):
    """Write a hierarchical Zarr array store.

    The object is converted with `to_adata()`, the tree attributes are added
    under ``uns["treedata_attrs"]`` of that converted object, and the result
    is written with its `write_zarr` method.

    Parameters
    ----------
    store
        The filename, a :class:`~typing.MutableMapping`, or a Zarr storage class.
    chunks
        Chunk shape.
    """
    adata = self.to_adata()
    # NOTE(review): `to_adata()` may hand back an object whose `uns` is the same
    # mapping as `self.uns` -- confirm against the anndata constructor. Rebinding
    # `uns` to a fresh mapping keeps the temporary entry out of `self.uns` either way.
    adata.uns = {**adata.uns, "treedata_attrs": self._treedata_attrs()}
    adata.write_zarr(store=store, chunks=chunks)

def to_memory(self, copy=False) -> TreeData:
    """Return a new TreeData object with all backed arrays loaded into memory.

    Parameters
    ----------
    copy
        Whether the arrays that are already in-memory should be copied.

    Returns
    -------
    A new TreeData object built from the parent's `to_memory` result, with
    copies of `obst` and `vart` and the same label and `allow_overlap` setting.
    """
    adata = super().to_memory(copy)
    return TreeData(
        adata,
        obst=self.obst.copy(),
        vart=self.vart.copy(),
        label=self.label,
        allow_overlap=self.allow_overlap,
    )
25 changes: 25 additions & 0 deletions src/treedata/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,28 @@ def combine_trees(subsets: list[nx.DiGraph]) -> nx.DiGraph:

# The combined_tree now contains all nodes and edges from the subsets
return combined_tree


def digraph_to_dict(G: nx.DiGraph) -> dict:
    """Convert a networkx.DiGraph to a dictionary.

    Parameters
    ----------
    G
        Graph to convert. A new ``nx.DiGraph`` is built from it first, and the
        dictionary is derived from that new graph.

    Returns
    -------
    A dictionary with two keys: ``"edges"``, the result of
    ``nx.to_dict_of_dicts``, and ``"nodes"``, mapping each node to its
    attribute dictionary.
    """
    graph = nx.DiGraph(G)
    return {
        "edges": nx.to_dict_of_dicts(graph),
        "nodes": {name: graph.nodes[name] for name in graph.nodes()},
    }


def dict_to_digraph(graph_dict: dict) -> nx.DiGraph:
    """Convert a dictionary to a networkx.DiGraph.

    Parameters
    ----------
    graph_dict
        Dictionary with a ``"nodes"`` entry mapping each node to its attributes
        and an ``"edges"`` entry mapping each source node to a dictionary of
        ``{target: attributes}``.

    Returns
    -------
    A new directed graph containing those nodes and edges with their attributes.
    """
    graph = nx.DiGraph()
    node_attrs = graph_dict["nodes"]
    adjacency = graph_dict["edges"]

    # Nodes go in first so that nodes without any edge are kept as well.
    for name, attributes in node_attrs.items():
        graph.add_node(name, **attributes)

    for parent, children in adjacency.items():
        for child, attributes in children.items():
            graph.add_edge(parent, child, **attributes)

    return graph
12 changes: 10 additions & 2 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ def tree():


def check_graph_equality(g1, g2):
    """Assert that two graphs have equal nodes and edges and are isomorphic.

    The match callbacks passed to ``nx.is_isomorphic`` compare their two
    arguments for equality.
    """

    def _same(first, second):
        return first == second

    assert g1.nodes == g2.nodes
    assert g1.edges == g2.edges
    assert nx.is_isomorphic(g1, g2, node_match=_same, edge_match=_same)


def test_creation(X, adata, tree):
Expand Down Expand Up @@ -173,3 +172,12 @@ def test_copy(adata, tree):
assert np.array_equal(treedata.X, treedata_copy.X)
assert treedata.obst["tree"].nodes == treedata_copy.obst["tree"].nodes
assert treedata.obst["tree"].edges == treedata_copy.obst["tree"].edges


def test_transpose(adata, tree):
    """Transposing swaps the axes: `X` is transposed and obs trees become var trees."""
    treedata = td.TreeData(adata, obst={"tree": tree})
    treedata_transpose = treedata.transpose()
    assert np.array_equal(treedata.X.T, treedata_transpose.X)
    assert treedata.obst["tree"].nodes == treedata_transpose.vart["tree"].nodes
    assert treedata_transpose.obst_keys() == []
    # The observations of the transposed object are the original variables, so
    # compare against `var_names`; comparing `obs_names` with `T.obs_names`
    # only holds when both axes happen to carry the same names.
    assert np.array_equal(treedata.var_names, treedata.T.obs_names)
Loading

0 comments on commit c9fd602

Please sign in to comment.