diff --git a/.gitignore b/.gitignore index 8ff1664..7d0122c 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,6 @@ __pycache__/ # Prettier /node_modules/ + +# Environment +environment.yml diff --git a/CHANGELOG.md b/CHANGELOG.md index 026b078..2847c25 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,18 @@ and this project adheres to [Semantic Versioning][]. ### Fixed +## [0.1.1] - 2024-11-25 + +### Added + +- Axis in `td.concat` can now be specified with `obs` and `var` (#40) + +### Changed + +### Fixed + +- Fixed `ImportError: cannot import name '_resolve_dim' from 'anndata._core.merge'` caused by anndata update (#40) + ## [0.1.0] - 2024-09-27 ### Added diff --git a/pyproject.toml b/pyproject.toml index 6edec31..bd19974 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ requires = ["hatchling"] [project] name = "treedata" -version = "0.1.0" +version = "0.1.1" description = "anndata with trees" readme = "README.md" requires-python = ">=3.10" @@ -22,6 +22,7 @@ dependencies = [ "anndata", "h5py", "numpy", + "packaging", "pandas", "pathlib", "pyarrow", diff --git a/src/treedata/_core/merge.py b/src/treedata/_core/merge.py index 35ae62b..a81a688 100755 --- a/src/treedata/_core/merge.py +++ b/src/treedata/_core/merge.py @@ -13,9 +13,9 @@ import anndata as ad import pandas as pd -from anndata._core.merge import _resolve_dim, resolve_merge_strategy +from anndata._core.merge import resolve_merge_strategy -from treedata._utils import combine_trees +from treedata._utils import _resolve_axis, combine_trees from .treedata import TreeData @@ -25,7 +25,7 @@ def concat( tdatas: Collection[TreeData] | typing.Mapping[str, TreeData], *, - axis: Literal[0, 1] = 0, + axis: Literal["obs", 0, "var", 1] = "obs", join: Literal["inner", "outer"] = "inner", merge: StrategiesLiteral | Callable | None = None, uns_merge: StrategiesLiteral | Callable | None = None, @@ -74,8 +74,8 @@ def concat( Whether pairwise elements along the concatenated dimension should be included. This is False by default, since the resulting arrays are often not meaningful. """ - axis, dim = _resolve_dim(axis=axis) - alt_axis, alt_dim = _resolve_dim(axis=1 - axis) + axis, dim = _resolve_axis(axis) + alt_axis, alt_dim = _resolve_axis(axis=1 - axis) merge = resolve_merge_strategy(merge) # Check indices diff --git a/src/treedata/_core/read.py b/src/treedata/_core/read.py index 6a804c8..3bd9466 100755 --- a/src/treedata/_core/read.py +++ b/src/treedata/_core/read.py @@ -11,9 +11,21 @@ import h5py import networkx as nx import zarr +from packaging import version from treedata._core.treedata import TreeData +ANDATA_VERSION = version.parse(ad.__version__) +USE_EXPERIMENTAL = ANDATA_VERSION < version.parse("0.11.0") + + +def _read_elem(elem): + """Read an element from a store.""" + if USE_EXPERIMENTAL: + return ad.experimental.read_elem(elem) + else: + return ad.io.read_elem(elem) + def _dict_to_digraph(graph_dict: dict) -> nx.DiGraph: """Convert a dictionary to a networkx.DiGraph.""" @@ -49,9 +61,9 @@ def _read_raw(f, backed): d = {} for k in ["obs", "var"]: if f"raw/{k}" in f: - d[k] = ad.experimental.read_elem(f[f"raw/{k}"]) + d[k] = _read_elem(f[f"raw/{k}"]) if not backed: - d["X"] = ad.experimental.read_elem(f["raw/X"]) + d["X"] = _read_elem(f["raw/X"]) return d @@ -64,23 +76,23 @@ def _read_tdata(f, filename, backed) -> dict: backed = "r" # Read X if not backed if not backed: - d["X"] = ad.experimental.read_elem(f["X"]) + d["X"] = _read_elem(f["X"]) else: d.update({"filename": filename, "filemode": backed}) # Read standard elements for k in ["obs", "var", "obsm", "varm", "obsp", "varp", "layers", "uns", "label", "allow_overlap"]: if k in f: - d[k] = ad.experimental.read_elem(f[k]) + d[k] = _read_elem(f[k]) # Read raw if "raw" in f: d["raw"] = _read_raw(f, backed) # Read axis tree elements for k in ["obst", "vart"]: if k in f: - d[k] = _parse_axis_trees(ad.experimental.read_elem(f[k])) + d[k] = _parse_axis_trees(_read_elem(f[k])) # Read legacy treedata format if "raw.treedata" in f: - d.update(_parse_legacy(json.loads(ad.experimental.read_elem(f["raw.treedata"])))) + d.update(_parse_legacy(json.loads(_read_elem(f["raw.treedata"])))) return d diff --git a/src/treedata/_core/write.py b/src/treedata/_core/write.py index 39548b2..c5e6c1e 100755 --- a/src/treedata/_core/write.py +++ b/src/treedata/_core/write.py @@ -13,10 +13,14 @@ import numpy as np import pandas as pd import zarr +from packaging import version from treedata._core.aligned_mapping import AxisTrees from treedata._core.treedata import TreeData +ANDATA_VERSION = version.parse(ad.__version__) +USE_EXPERIMENTAL = ANDATA_VERSION < version.parse("0.11.0") + def _make_serializable(data: dict) -> dict: """Make a dictionary serializable.""" @@ -34,6 +38,14 @@ def _make_serializable(data: dict) -> dict: return data +def _write_elem(f, k, elem, *, dataset_kwargs) -> None: + """Write an element to a storage group using anndata encoding.""" + if USE_EXPERIMENTAL: + ad.experimental.write_elem(f, k, elem, dataset_kwargs=dataset_kwargs) + else: + ad.io.write_elem(f, k, elem, dataset_kwargs=dataset_kwargs) + + def _digraph_to_dict(G: nx.DiGraph) -> dict: """Convert a networkx.DiGraph to a dictionary.""" G = nx.DiGraph(G) @@ -61,20 +73,20 @@ def _write_tdata(f, tdata, filename, **kwargs) -> None: tdata.strings_to_categoricals() # Write X if not backed if not (tdata.isbacked and Path(tdata.filename) == Path(filename)): - ad.experimental.write_elem(f, "X", tdata.X, dataset_kwargs=kwargs) + _write_elem(f, "X", tdata.X, dataset_kwargs=kwargs) # Write array elements for key in ["obs", "var", "label", "allow_overlap"]: - ad.experimental.write_elem(f, key, getattr(tdata, key), dataset_kwargs=kwargs) + _write_elem(f, key, getattr(tdata, key), dataset_kwargs=kwargs) # Write group elements for key in ["obsm", "varm", "obsp", "varp", "layers", "uns"]: - ad.experimental.write_elem(f, key, dict(getattr(tdata, key)), dataset_kwargs=kwargs) + _write_elem(f, key, dict(getattr(tdata, key)), dataset_kwargs=kwargs) # Write axis tree elements for key in ["obst", "vart"]: - ad.experimental.write_elem(f, key, _serialize_axis_trees(getattr(tdata, key)), dataset_kwargs=kwargs) + _write_elem(f, key, _serialize_axis_trees(getattr(tdata, key)), dataset_kwargs=kwargs) # Write raw if tdata.raw is not None: tdata.strings_to_categoricals(tdata.raw.var) - ad.experimental.write_elem(f, "raw", tdata.raw, dataset_kwargs=kwargs) + _write_elem(f, "raw", tdata.raw, dataset_kwargs=kwargs) # Close the file tdata.file.close() diff --git a/src/treedata/_utils.py b/src/treedata/_utils.py index 9f360fd..456181b 100755 --- a/src/treedata/_utils.py +++ b/src/treedata/_utils.py @@ -1,4 +1,5 @@ from collections import deque +from typing import Literal import networkx as nx @@ -34,3 +35,14 @@ def combine_trees(subsets: list[nx.DiGraph]) -> nx.DiGraph: # The combined_tree now contains all nodes and edges from the subsets return combined_tree + + +def _resolve_axis( + axis: Literal["obs", 0, "var", 1], +) -> tuple[Literal[0], Literal["obs"]] | tuple[Literal[1], Literal["var"]]: + """Resolve axis argument.""" + if axis in {0, "obs"}: + return (0, "obs") + if axis in {1, "var"}: + return (1, "var") + raise ValueError(f"`axis` must be either 0, 1, 'obs', or 'var', was {axis}") diff --git a/tests/test_merge.py b/tests/test_merge.py index 11d7a1b..c86d699 100755 --- a/tests/test_merge.py +++ b/tests/test_merge.py @@ -35,7 +35,7 @@ def tdata_list(tdata): def test_concat(tdata_list): # outer join - tdata = td.concat(tdata_list, axis=0, label="subset", join="outer") + tdata = td.concat(tdata_list, axis="obs", label="subset", join="outer") print(tdata) assert list(tdata.obs["subset"]) == ["0"] * 2 + ["1"] * 2 + ["2"] * 4 assert tdata.obst["0"].number_of_nodes() == 15