Skip to content

Commit

Permalink
Merge _axisp_get_array and _axism_get_array (#3554)
Browse files Browse the repository at this point in the history
Rather than deconstructing and reconstruction annotation names (`obsm`, `varp`, etc.) create a single function `_get_annotation_layer` that does not use the `AxisName` enum.
  • Loading branch information
jp-dark authored Jan 16, 2025
1 parent 270b222 commit c70e41f
Showing 1 changed file with 38 additions and 43 deletions.
81 changes: 38 additions & 43 deletions apis/python/src/tiledbsoma/_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@
from anndata import AnnData
from somacore import (
AxisQuery,
Collection,
DataFrame,
NDArray,
ReadIter,
SparseRead,
query,
Expand Down Expand Up @@ -342,29 +344,29 @@ def obsp(self, layer: str) -> SparseRead:
Lifecycle: maturing
"""
joinids = self._joinids.obs
return self._axisp_get_array(AxisName.OBS, layer).read((joinids, joinids))
return self._get_annotation_layer("obsp", layer).read((joinids, joinids))

def varp(self, layer: str) -> SparseRead:
"""Returns a ``varp`` layer as a sparse read.
Lifecycle: maturing
"""
joinids = self._joinids.var
return self._axisp_get_array(AxisName.VAR, layer).read((joinids, joinids))
return self._get_annotation_layer("varp", layer).read((joinids, joinids))

def obsm(self, layer: str) -> SparseRead:
"""Returns an ``obsm`` layer as a sparse read.
Lifecycle: maturing
"""
return self._axism_get_array(AxisName.OBS, layer).read(
return self._get_annotation_layer("obsm", layer).read(
(self._joinids.obs, slice(None))
)

def varm(self, layer: str) -> SparseRead:
"""Returns a ``varm`` layer as a sparse read.
Lifecycle: maturing
"""
return self._axism_get_array(AxisName.VAR, layer).read(
return self._get_annotation_layer("varm", layer).read(
(self._joinids.var, slice(None))
)

Expand Down Expand Up @@ -823,52 +825,39 @@ def _read_axis_dataframe(
arrow_table = arrow_table.drop(["soma_joinid"])
return arrow_table

def _axisp_get_array(
self,
axis: AxisName,
layer: str,
def _get_annotation_layer(
self, annotation_name: str, layer_name: str
) -> SparseNDArray:
p_name = f"{axis.value}p"
try:
ms = self._ms
axisp = ms.obsp if axis.value == "obs" else ms.varp
except (AttributeError, KeyError):
raise ValueError(f"Measurement does not contain {p_name} data")
"""Helper function to make error messages consistent.
Args:
annotation_name:
Name of the annotation (e.g. obsm, varp).
layer_name:
Name of the layer.
"""
try:
axisp_layer = axisp[layer]
coll: Collection[NDArray] = self._ms[annotation_name] # type: ignore
except KeyError:
raise ValueError(f"layer {layer!r} is not available in {p_name}")
if not isinstance(axisp_layer, SparseNDArray):
raise ValueError(f"Measurement does not contain {annotation_name!r} data.")
if not isinstance(coll, Collection):
raise TypeError(
f"Unexpected SOMA type {type(axisp_layer).__name__}"
f" stored in {p_name} layer {layer!r}"
f"Unexpected SOMA type {type(coll).__name__} for "
f"{annotation_name!r}."
)

return axisp_layer

def _axism_get_array(
self,
axis: AxisName,
layer: str,
) -> SparseNDArray:
m_name = f"{axis.value}m"

try:
ms = self._ms
axism = ms.obsm if axis.value == "obs" else ms.varm
except (AttributeError, KeyError):
raise ValueError(f"Measurement does not contain {m_name} data")

try:
axism_layer = axism[layer]
layer = coll[layer_name]
except KeyError:
raise ValueError(f"layer {layer!r} is not available in {m_name}")

if not isinstance(axism_layer, SparseNDArray):
raise TypeError(f"Unexpected SOMA type stored in '{m_name}' layer")

return axism_layer
raise ValueError(
f"layer {layer_name!r} is not available in {annotation_name!r}."
)
if not isinstance(layer, SparseNDArray):
raise TypeError(
f"Unexpected SOMA type {type(layer).__name__} stored in "
f"{annotation_name!r} layer {layer_name!r}."
)
return layer

def _convert_to_ndarray(
self, axis: AxisName, table: pa.Table, n_row: int, n_col: int
Expand All @@ -892,8 +881,13 @@ def _axisp_inner_sparray(
Callable[[Numpyable], npt.NDArray[np.intp]],
axis.getattr_from(self.indexer, pre="by_"),
)
annotation_name = f"{axis.value}p"
return _read_as_csr(
self._axisp_get_array(axis, layer), joinids, joinids, indexer, indexer
self._get_annotation_layer(annotation_name, layer),
joinids,
joinids,
indexer,
indexer,
)

def _axism_inner_ndarray(
Expand All @@ -902,8 +896,9 @@ def _axism_inner_ndarray(
layer: str,
) -> npt.NDArray[np.float32]:
joinids = axis.getattr_from(self._joinids)
annotation_name = f"{axis.value}m"
table = (
self._axism_get_array(axis, layer)
self._get_annotation_layer(annotation_name, layer)
.read((joinids, slice(None)))
.tables()
.concat()
Expand Down

0 comments on commit c70e41f

Please sign in to comment.