Merge pull request #39 from YosefLab/resolve-dim-patch
fixed _resolve_dim
colganwi authored Nov 15, 2024
2 parents 9f9802d + a81b830 commit cd258b7
Showing 8 changed files with 70 additions and 18 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -30,3 +30,6 @@ __pycache__/
 
 # Prettier
 /node_modules/
+
+# Environment
+environment.yml
12 changes: 12 additions & 0 deletions CHANGELOG.md
@@ -16,6 +16,18 @@ and this project adheres to [Semantic Versioning][].
 
 ### Fixed
 
+## [0.1.1] - 2024-11-25
+
+### Added
+
+- Axis in `td.concat` can now be specified with `obs` and `var` (#40)
+
+### Changed
+
+### Fixed
+
+- Fixed `ImportError: cannot import name '_resolve_dim' from 'anndata._core.merge'` caused by anndata update (#40)
+
 ## [0.1.0] - 2024-09-27
 
 ### Added
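
For context on the two CHANGELOG entries above, here is a minimal sketch of the new axis spelling in `td.concat`; the objects below are illustrative and not taken from this PR:

import numpy as np
import pandas as pd
import treedata as td

# two tiny TreeData objects sharing the same variables (illustrative only)
a = td.TreeData(np.ones((2, 3)), obs=pd.DataFrame(index=["c1", "c2"]))
b = td.TreeData(np.zeros((2, 3)), obs=pd.DataFrame(index=["c3", "c4"]))

# axis can now be passed as "obs"/"var" as well as 0/1
merged = td.concat([a, b], axis="obs", join="outer")  # equivalent to axis=0
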
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ requires = ["hatchling"]
 
 [project]
 name = "treedata"
-version = "0.1.0"
+version = "0.1.1"
 description = "anndata with trees"
 readme = "README.md"
 requires-python = ">=3.10"
@@ -22,6 +22,7 @@ dependencies = [
     "anndata",
     "h5py",
     "numpy",
+    "packaging",
     "pandas",
     "pathlib",
     "pyarrow",
10 changes: 5 additions & 5 deletions src/treedata/_core/merge.py
@@ -13,9 +13,9 @@
 
 import anndata as ad
 import pandas as pd
-from anndata._core.merge import _resolve_dim, resolve_merge_strategy
+from anndata._core.merge import resolve_merge_strategy
 
-from treedata._utils import combine_trees
+from treedata._utils import _resolve_axis, combine_trees
 
 from .treedata import TreeData
 
@@ -25,7 +25,7 @@
 def concat(
     tdatas: Collection[TreeData] | typing.Mapping[str, TreeData],
     *,
-    axis: Literal[0, 1] = 0,
+    axis: Literal["obs", 0, "var", 1] = "obs",
     join: Literal["inner", "outer"] = "inner",
     merge: StrategiesLiteral | Callable | None = None,
     uns_merge: StrategiesLiteral | Callable | None = None,
@@ -74,8 +74,8 @@ def concat(
         Whether pairwise elements along the concatenated dimension should be included.
         This is False by default, since the resulting arrays are often not meaningful.
     """
-    axis, dim = _resolve_dim(axis=axis)
-    alt_axis, alt_dim = _resolve_dim(axis=1 - axis)
+    axis, dim = _resolve_axis(axis)
+    alt_axis, alt_dim = _resolve_axis(axis=1 - axis)
     merge = resolve_merge_strategy(merge)
 
     # Check indices
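
For reference, a sketch of what the swapped-in helper resolves to for the two calls above, based on the `_resolve_axis` definition added to `src/treedata/_utils.py` further down in this diff:

from treedata._utils import _resolve_axis  # private helper introduced in this PR

axis, dim = _resolve_axis("obs")                  # -> (0, "obs")
alt_axis, alt_dim = _resolve_axis(axis=1 - axis)  # -> (1, "var"), the alternate dimension
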
24 changes: 18 additions & 6 deletions src/treedata/_core/read.py
@@ -11,9 +11,21 @@
 import h5py
 import networkx as nx
 import zarr
+from packaging import version
 
 from treedata._core.treedata import TreeData
 
+ANDATA_VERSION = version.parse(ad.__version__)
+USE_EXPERIMENTAL = ANDATA_VERSION < version.parse("0.11.0")
+
+
+def _read_elem(elem):
+    """Read an element from a store."""
+    if USE_EXPERIMENTAL:
+        return ad.experimental.read_elem(elem)
+    else:
+        return ad.io.read_elem(elem)
+
 
 def _dict_to_digraph(graph_dict: dict) -> nx.DiGraph:
     """Convert a dictionary to a networkx.DiGraph."""
@@ -49,9 +61,9 @@ def _read_raw(f, backed):
     d = {}
     for k in ["obs", "var"]:
         if f"raw/{k}" in f:
-            d[k] = ad.experimental.read_elem(f[f"raw/{k}"])
+            d[k] = _read_elem(f[f"raw/{k}"])
     if not backed:
-        d["X"] = ad.experimental.read_elem(f["raw/X"])
+        d["X"] = _read_elem(f["raw/X"])
     return d
 
 
@@ -64,23 +76,23 @@ def _read_tdata(f, filename, backed) -> dict:
         backed = "r"
     # Read X if not backed
     if not backed:
-        d["X"] = ad.experimental.read_elem(f["X"])
+        d["X"] = _read_elem(f["X"])
     else:
         d.update({"filename": filename, "filemode": backed})
     # Read standard elements
     for k in ["obs", "var", "obsm", "varm", "obsp", "varp", "layers", "uns", "label", "allow_overlap"]:
         if k in f:
-            d[k] = ad.experimental.read_elem(f[k])
+            d[k] = _read_elem(f[k])
     # Read raw
     if "raw" in f:
         d["raw"] = _read_raw(f, backed)
     # Read axis tree elements
     for k in ["obst", "vart"]:
         if k in f:
-            d[k] = _parse_axis_trees(ad.experimental.read_elem(f[k]))
+            d[k] = _parse_axis_trees(_read_elem(f[k]))
     # Read legacy treedata format
     if "raw.treedata" in f:
-        d.update(_parse_legacy(json.loads(ad.experimental.read_elem(f["raw.treedata"]))))
+        d.update(_parse_legacy(json.loads(_read_elem(f["raw.treedata"]))))
     return d
 
 
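
A rough usage sketch of the `_read_elem` wrapper above; the file name is hypothetical and the helper is private, so treat this as illustration rather than public API:

import h5py

from treedata._core.read import _read_elem  # private wrapper added in this PR

# read a single element from an on-disk store (hypothetical file)
with h5py.File("example.h5ad", "r") as f:
    obs = _read_elem(f["obs"])  # ad.io.read_elem on anndata >= 0.11, ad.experimental.read_elem otherwise
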
22 changes: 17 additions & 5 deletions src/treedata/_core/write.py
@@ -13,10 +13,14 @@
 import numpy as np
 import pandas as pd
 import zarr
+from packaging import version
 
 from treedata._core.aligned_mapping import AxisTrees
 from treedata._core.treedata import TreeData
 
+ANDATA_VERSION = version.parse(ad.__version__)
+USE_EXPERIMENTAL = ANDATA_VERSION < version.parse("0.11.0")
+
 
 def _make_serializable(data: dict) -> dict:
     """Make a dictionary serializable."""
@@ -34,6 +38,14 @@ def _make_serializable(data: dict) -> dict:
     return data
 
 
+def _write_elem(f, k, elem, *, dataset_kwargs) -> None:
+    """Write an element to a storage group using anndata encoding."""
+    if USE_EXPERIMENTAL:
+        ad.experimental.write_elem(f, k, elem, dataset_kwargs=dataset_kwargs)
+    else:
+        ad.io.write_elem(f, k, elem, dataset_kwargs=dataset_kwargs)
+
+
 def _digraph_to_dict(G: nx.DiGraph) -> dict:
     """Convert a networkx.DiGraph to a dictionary."""
     G = nx.DiGraph(G)
@@ -61,20 +73,20 @@ def _write_tdata(f, tdata, filename, **kwargs) -> None:
     tdata.strings_to_categoricals()
     # Write X if not backed
     if not (tdata.isbacked and Path(tdata.filename) == Path(filename)):
-        ad.experimental.write_elem(f, "X", tdata.X, dataset_kwargs=kwargs)
+        _write_elem(f, "X", tdata.X, dataset_kwargs=kwargs)
     # Write array elements
     for key in ["obs", "var", "label", "allow_overlap"]:
-        ad.experimental.write_elem(f, key, getattr(tdata, key), dataset_kwargs=kwargs)
+        _write_elem(f, key, getattr(tdata, key), dataset_kwargs=kwargs)
     # Write group elements
     for key in ["obsm", "varm", "obsp", "varp", "layers", "uns"]:
-        ad.experimental.write_elem(f, key, dict(getattr(tdata, key)), dataset_kwargs=kwargs)
+        _write_elem(f, key, dict(getattr(tdata, key)), dataset_kwargs=kwargs)
     # Write axis tree elements
     for key in ["obst", "vart"]:
-        ad.experimental.write_elem(f, key, _serialize_axis_trees(getattr(tdata, key)), dataset_kwargs=kwargs)
+        _write_elem(f, key, _serialize_axis_trees(getattr(tdata, key)), dataset_kwargs=kwargs)
     # Write raw
     if tdata.raw is not None:
         tdata.strings_to_categoricals(tdata.raw.var)
-        ad.experimental.write_elem(f, "raw", tdata.raw, dataset_kwargs=kwargs)
+        _write_elem(f, "raw", tdata.raw, dataset_kwargs=kwargs)
     # Close the file
     tdata.file.close()
 
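
And the mirror-image sketch for `_write_elem`; the zarr store path is hypothetical, and the call simply shows how the wrapper forwards `dataset_kwargs` to whichever anndata writer is available:

import pandas as pd
import zarr

from treedata._core.write import _write_elem  # private wrapper added in this PR

obs = pd.DataFrame({"cluster": ["a", "b"]}, index=["c1", "c2"])
store = zarr.open_group("example.zarr", mode="w")  # hypothetical store
_write_elem(store, "obs", obs, dataset_kwargs={})  # dispatches to ad.io or ad.experimental
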
12 changes: 12 additions & 0 deletions src/treedata/_utils.py
@@ -1,4 +1,5 @@
 from collections import deque
+from typing import Literal
 
 import networkx as nx
 
@@ -34,3 +35,14 @@ def combine_trees(subsets: list[nx.DiGraph]) -> nx.DiGraph:
 
     # The combined_tree now contains all nodes and edges from the subsets
     return combined_tree
+
+
+def _resolve_axis(
+    axis: Literal["obs", 0, "var", 1],
+) -> tuple[Literal[0], Literal["obs"]] | tuple[Literal[1], Literal["var"]]:
+    """Resolve axis argument."""
+    if axis in {0, "obs"}:
+        return (0, "obs")
+    if axis in {1, "var"}:
+        return (1, "var")
+    raise ValueError(f"`axis` must be either 0, 1, 'obs', or 'var', was {axis}")
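
Not part of this PR, but a quick pytest-style sketch of the behaviour encoded above, assuming the helper remains importable from `treedata._utils`:

import pytest

from treedata._utils import _resolve_axis


def test_resolve_axis_accepts_both_spellings():
    assert _resolve_axis(0) == (0, "obs")
    assert _resolve_axis("var") == (1, "var")
    with pytest.raises(ValueError):
        _resolve_axis("cells")
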
2 changes: 1 addition & 1 deletion tests/test_merge.py
@@ -35,7 +35,7 @@ def tdata_list(tdata):
 
 def test_concat(tdata_list):
     # outer join
-    tdata = td.concat(tdata_list, axis=0, label="subset", join="outer")
+    tdata = td.concat(tdata_list, axis="obs", label="subset", join="outer")
     print(tdata)
     assert list(tdata.obs["subset"]) == ["0"] * 2 + ["1"] * 2 + ["2"] * 4
     assert tdata.obst["0"].number_of_nodes() == 15
