Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify X and layers #1707

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 52 additions & 38 deletions src/anndata/_core/aligned_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from collections.abc import MutableMapping, Sequence
from copy import copy
from dataclasses import dataclass
from types import NoneType
from typing import TYPE_CHECKING, Generic, TypeVar

import numpy as np
Expand Down Expand Up @@ -38,12 +39,13 @@
# TODO: pd.DataFrame only allowed in AxisArrays?
Value = pd.DataFrame | spmatrix | np.ndarray

K = TypeVar("K", str, str | None)
P = TypeVar("P", bound="AlignedMappingBase")
"""Parent mapping an AlignedView is based on."""
I = TypeVar("I", OneDIdx, TwoDIdx)


class AlignedMappingBase(MutableMapping[str, Value], ABC):
class AlignedMappingBase(MutableMapping[K, Value], ABC, Generic[K]):
"""\
An abstract base class for Mappings containing array-like values aligned
to either one or both AnnData axes.
Expand All @@ -61,13 +63,13 @@
_parent: AnnData | Raw
"""The parent object that this mapping is aligned to."""

def __repr__(self):
return f"{type(self).__name__} with keys: {', '.join(self.keys())}"
def __repr__(self) -> str:
return f"{type(self).__name__} with keys: {', '.join(map(repr, self.keys()))}"

def _ipython_key_completions_(self) -> list[str]:
def _ipython_key_completions_(self) -> list[K]:
return list(self.keys())

def _validate_value(self, val: Value, key: str) -> Value:
def _validate_value(self, val: Value, key: K) -> Value:
"""Raises an error if value is invalid"""
if isinstance(val, AwkArray):
warn_once(
Expand Down Expand Up @@ -117,13 +119,14 @@
def parent(self) -> AnnData | Raw:
return self._parent

def copy(self) -> dict[str, Value]:
def copy(self) -> dict[K, Value]:
# Shallow copy for awkward array since their buffers are immutable
return {
k: copy(v) if isinstance(v, AwkArray) else v.copy() for k, v in self.items()
k: copy(v) if isinstance(v, AwkArray | NoneType) else v.copy()
for k, v in self.items()
}

def _view(self, parent: AnnData, subset_idx: I) -> AlignedView[Self, I]:
def _view(self, parent: AnnData, subset_idx: I) -> AlignedView[K, Self, I]:
"""Returns a subset copy-on-write view of the object."""
return self._view_class(self, parent, subset_idx)

Expand All @@ -132,7 +135,7 @@
return dict(self)


class AlignedView(AlignedMappingBase, Generic[P, I]):
class AlignedView(AlignedMappingBase[K], Generic[K, P, I]):
is_view: ClassVar[Literal[True]] = True

# override docstring
Expand All @@ -156,13 +159,15 @@
# LayersBase has no _axis, the rest does
self._axis = parent_mapping._axis # type: ignore

def __getitem__(self, key: str) -> Value:
def __getitem__(self, key: K) -> Value:
if self.parent_mapping[key] is None:
return None
return as_view(
_subset(self.parent_mapping[key], self.subset_idx),
ElementRef(self.parent, self.attrname, (key,)),
)

def __setitem__(self, key: str, value: Value) -> None:
def __setitem__(self, key: K, value: Value) -> None:
value = self._validate_value(value, key) # Validate before mutating
warnings.warn(
f"Setting element `.{self.attrname}['{key}']` of view, "
Expand All @@ -171,9 +176,12 @@
stacklevel=2,
)
with view_update(self.parent, self.attrname, ()) as new_mapping:
new_mapping[key] = value
if value is None:
del new_mapping[key]

Check warning on line 180 in src/anndata/_core/aligned_mapping.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/_core/aligned_mapping.py#L180

Added line #L180 was not covered by tests
else:
new_mapping[key] = value

def __delitem__(self, key: str) -> None:
def __delitem__(self, key: K) -> None:
if key not in self:
raise KeyError(
"'{key!r}' not found in view of {self.attrname}"
Expand All @@ -187,49 +195,52 @@
with view_update(self.parent, self.attrname, ()) as new_mapping:
del new_mapping[key]

def __contains__(self, key: str) -> bool:
def __contains__(self, key: K) -> bool:
return key in self.parent_mapping

def __iter__(self) -> Iterator[str]:
def __iter__(self) -> Iterator[K]:
return iter(self.parent_mapping)

def __len__(self) -> int:
return len(self.parent_mapping)


class AlignedActual(AlignedMappingBase):
class AlignedActual(AlignedMappingBase[K], Generic[K]):
is_view: ClassVar[Literal[False]] = False

_data: MutableMapping[str, Value]
_data: MutableMapping[K, Value]
"""Underlying mapping to the data"""

def __init__(self, parent: AnnData | Raw, *, store: MutableMapping[str, Value]):
def __init__(self, parent: AnnData | Raw, *, store: MutableMapping[K, Value]):
self._parent = parent
self._data = store
for k, v in self._data.items():
if v is None:
continue
self._data[k] = self._validate_value(v, k)

def __getitem__(self, key: str) -> Value:
def __getitem__(self, key: K) -> Value:
return self._data[key]

def __setitem__(self, key: str, value: Value):
value = self._validate_value(value, key)
def __setitem__(self, key: K, value: Value):
if value is not None:
value = self._validate_value(value, key)
self._data[key] = value

def __contains__(self, key: str) -> bool:
def __contains__(self, key: K) -> bool:
return key in self._data

def __delitem__(self, key: str):
def __delitem__(self, key: K):
del self._data[key]

def __iter__(self) -> Iterator[str]:
def __iter__(self) -> Iterator[K]:
return iter(self._data)

def __len__(self) -> int:
return len(self._data)


class AxisArraysBase(AlignedMappingBase):
class AxisArraysBase(AlignedMappingBase[str]):
"""\
Mapping of key→array-like,
where array-like is aligned to an axis of parent AnnData.
Expand Down Expand Up @@ -283,7 +294,7 @@
return (self.parent.obs_names, self.parent.var_names)[self._axis]


class AxisArrays(AlignedActual, AxisArraysBase):
class AxisArrays(AlignedActual[str], AxisArraysBase):
def __init__(
self,
parent: AnnData | Raw,
Expand All @@ -297,15 +308,15 @@
super().__init__(parent, store=store)


class AxisArraysView(AlignedView[AxisArraysBase, OneDIdx], AxisArraysBase):
class AxisArraysView(AlignedView[str, AxisArraysBase, OneDIdx], AxisArraysBase):
pass


AxisArraysBase._view_class = AxisArraysView
AxisArraysBase._actual_class = AxisArrays


class LayersBase(AlignedMappingBase):
class LayersBase(AlignedMappingBase[str | None]):
"""\
Mapping of key: array-like, where array-like is aligned to both axes of the
parent anndata.
Expand All @@ -316,19 +327,19 @@
axes: ClassVar[tuple[Literal[0], Literal[1]]] = (0, 1)


class Layers(AlignedActual, LayersBase):
class Layers(AlignedActual[str | None], LayersBase):
pass


class LayersView(AlignedView[LayersBase, TwoDIdx], LayersBase):
class LayersView(AlignedView[str | None, LayersBase, TwoDIdx], LayersBase):
pass


LayersBase._view_class = LayersView
LayersBase._actual_class = Layers


class PairwiseArraysBase(AlignedMappingBase):
class PairwiseArraysBase(AlignedMappingBase[str]):
"""\
Mapping of key: array-like, where both axes of array-like are aligned to
one axis of the parent anndata.
Expand All @@ -354,7 +365,7 @@
return self._dimnames[self._axis]


class PairwiseArrays(AlignedActual, PairwiseArraysBase):
class PairwiseArrays(AlignedActual[str], PairwiseArraysBase):
def __init__(
self,
parent: AnnData,
Expand All @@ -368,7 +379,9 @@
super().__init__(parent, store=store)


class PairwiseArraysView(AlignedView[PairwiseArraysBase, OneDIdx], PairwiseArraysBase):
class PairwiseArraysView(
AlignedView[str, PairwiseArraysBase, OneDIdx], PairwiseArraysBase
):
pass


Expand All @@ -389,7 +402,7 @@


@dataclass
class AlignedMappingProperty(property, Generic[T]):
class AlignedMappingProperty(property, Generic[K, T]):
"""A :class:`property` that creates an ephemeral AlignedMapping.

The actual data is stored as `f'_{self.name}'` in the parent object.
Expand All @@ -402,7 +415,7 @@
axis: Literal[0, 1] | None = None
"""Axis of the parent to align to."""

def construct(self, obj: AnnData, *, store: MutableMapping[str, Value]) -> T:
def construct(self, obj: AnnData, *, store: MutableMapping[K, Value]) -> T:
if self.axis is None:
return self.cls(obj, store=store)
return self.cls(obj, axis=self.axis, store=store)
Expand All @@ -429,13 +442,14 @@
return parent._view(obj, tuple(idxs[ax] for ax in parent.axes))

def __set__(
self, obj: AnnData, value: Mapping[str, Value] | Iterable[tuple[str, Value]]
self, obj: AnnData, value: Mapping[K, Value] | Iterable[tuple[K, Value]] | None
) -> None:
value = convert_to_dict(value)
_ = self.construct(obj, store=value) # Validate
if obj.is_view:
obj._init_as_actual(obj.copy())
setattr(obj, f"_{self.name}", value)

def __delete__(self, obj) -> None:
setattr(obj, self.name, dict())
def __delete__(self, obj: AnnData) -> None:
new = {None: x} if (x := getattr(obj, self.name).get(None)) is not None else {}
setattr(obj, self.name, new)
Loading
Loading