diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 54cbc5314b..2807cf368a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,7 +18,8 @@ repos: # Pandas 2.x types (e.g. `pd.Series[Any]`). See `_types.py` or https://github.com/single-cell-data/TileDB-SOMA/issues/2839 # for more info. - "pandas-stubs>=2" - - "somacore==1.0.23" + # Temporary, for PR: see https://github.com/single-cell-data/SOMA/pull/244 + - "git+https://github.com/single-cell-data/soma@9e81f07" - types-setuptools args: ["--config-file=apis/python/pyproject.toml", "apis/python/src", "apis/python/devtools"] pass_filenames: false diff --git a/apis/python/setup.py b/apis/python/setup.py index 3742b7a6cd..e014d0334d 100644 --- a/apis/python/setup.py +++ b/apis/python/setup.py @@ -336,8 +336,8 @@ def run(self): "pyarrow", "scanpy>=1.9.2", "scipy", - # Note: the somacore version is in .pre-commit-config.yaml too - "somacore==1.0.23", + # Temporary, for PR: see https://github.com/single-cell-data/SOMA/pull/244 + "somacore @ git+https://github.com/single-cell-data/soma@rw/abcs", "typing-extensions", # Note "-" even though `import typing_extensions` ], extras_require={ diff --git a/apis/python/src/tiledbsoma/_experiment.py b/apis/python/src/tiledbsoma/_experiment.py index 5a57662761..dcd297d29e 100644 --- a/apis/python/src/tiledbsoma/_experiment.py +++ b/apis/python/src/tiledbsoma/_experiment.py @@ -9,13 +9,13 @@ from typing import Optional from somacore import experiment, query -from typing_extensions import Self from . 
import _tdb_handles from ._collection import Collection, CollectionBase from ._dataframe import DataFrame from ._indexer import IntIndexer from ._measurement import Measurement +from ._query import ExperimentAxisQuery from ._scene import Scene from ._soma_object import AnySOMAObject @@ -83,13 +83,11 @@ def axis_query( # type: ignore *, obs_query: Optional[query.AxisQuery] = None, var_query: Optional[query.AxisQuery] = None, - ) -> query.ExperimentAxisQuery[Self]: # type: ignore + ) -> ExperimentAxisQuery: """Creates an axis query over this experiment. Lifecycle: Maturing. """ - # mypy doesn't quite understand descriptors so it issues a spurious - # error here. - return query.ExperimentAxisQuery( # type: ignore + return ExperimentAxisQuery( self, measurement_name, obs_query=obs_query or query.AxisQuery(), diff --git a/apis/python/src/tiledbsoma/_indexer.py b/apis/python/src/tiledbsoma/_indexer.py index 4f39b0b042..2dc92f37c2 100644 --- a/apis/python/src/tiledbsoma/_indexer.py +++ b/apis/python/src/tiledbsoma/_indexer.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, List, Optional, Union +from typing import List, Optional, Union import numpy as np import numpy.typing as npt @@ -11,9 +11,7 @@ from tiledbsoma import pytiledbsoma as clib from ._types import PDSeries - -if TYPE_CHECKING: - from .options import SOMATileDBContext +from .options import SOMATileDBContext IndexerDataType = Union[ npt.NDArray[np.int64], @@ -27,7 +25,7 @@ def tiledbsoma_build_index( - data: IndexerDataType, *, context: Optional["SOMATileDBContext"] = None + data: IndexerDataType, *, context: Optional[SOMATileDBContext] = None ) -> IndexLike: """Initialize re-indexer for provided indices (deprecated). 
@@ -54,7 +52,7 @@ class IntIndexer: """ def __init__( - self, data: IndexerDataType, *, context: Optional["SOMATileDBContext"] = None + self, data: IndexerDataType, *, context: Optional[SOMATileDBContext] = None ): """Initialize re-indexer for provided indices. @@ -73,7 +71,7 @@ def __init__( ) self._reindexer.map_locations(data) - def get_indexer(self, target: IndexerDataType) -> Any: + def get_indexer(self, target: IndexerDataType) -> npt.NDArray[np.intp]: """Compute underlying indices of index for target data. Compatible with Pandas' Index.get_indexer method. @@ -81,7 +79,7 @@ def get_indexer(self, target: IndexerDataType) -> Any: Args: target: Data to return re-index data for. """ - return ( + return ( # type: ignore[no-any-return] self._reindexer.get_indexer_pyarrow(target) if isinstance(target, (pa.Array, pa.ChunkedArray)) else self._reindexer.get_indexer_general(target) diff --git a/apis/python/src/tiledbsoma/_query.py b/apis/python/src/tiledbsoma/_query.py new file mode 100644 index 0000000000..888fcaf1ab --- /dev/null +++ b/apis/python/src/tiledbsoma/_query.py @@ -0,0 +1,850 @@ +# Copyright (c) 2021-2023 The Chan Zuckerberg Initiative Foundation +# Copyright (c) 2021-2023 TileDB, Inc. +# +# Licensed under the MIT License. + +"""Implementation of a SOMA Experiment. 
+""" +import enum +from concurrent.futures import ThreadPoolExecutor +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Literal, + Mapping, + Optional, + Protocol, + Sequence, + Tuple, + TypeVar, + cast, + overload, +) + +import attrs +import numpy as np +import numpy.typing as npt +import pandas as pd +import pyarrow as pa +import pyarrow.compute as pacomp +import scipy.sparse as sp +from anndata import AnnData +from somacore import ( + AxisQuery, + DataFrame, + ReadIter, + SparseRead, + query, +) +from somacore.data import _RO_AUTO +from somacore.options import ( + BatchSize, + PlatformConfig, + ReadPartitions, + ResultOrder, + ResultOrderStr, +) +from somacore.query import _fast_csr +from somacore.query.query import ( + AxisColumnNames, + Numpyable, +) +from somacore.query.types import IndexFactory, IndexLike +from typing_extensions import Self + +if TYPE_CHECKING: + from ._experiment import Experiment +from ._measurement import Measurement +from ._sparse_nd_array import SparseNDArray + +_T = TypeVar("_T") +_T_co = TypeVar("_T_co", covariant=True) + + +class _HasObsVar(Protocol[_T_co]): + """Something which has an ``obs`` and ``var`` field. + + Used to give nicer type inference in :meth:`Axis.getattr_from`. + """ + + @property + def obs(self) -> _T_co: ... + + @property + def var(self) -> _T_co: ... + + +class Axis(enum.Enum): + OBS = "obs" + VAR = "var" + + @property + def value(self) -> Literal["obs", "var"]: + return super().value # type: ignore[no-any-return] + + @overload + def getattr_from(self, __source: _HasObsVar[_T]) -> _T: ... + + @overload + def getattr_from( + self, __source: Any, *, pre: Literal[""], suf: Literal[""] + ) -> object: ... + + @overload + def getattr_from( + self, __source: Any, *, pre: str = ..., suf: str = ... + ) -> object: ... + + def getattr_from(self, __source: Any, *, pre: str = "", suf: str = "") -> object: + """Equivalent to ``something.
+        <pre><obs/var><suf>``."""
+        return getattr(__source, pre + self.value + suf)
+
+    def getitem_from(
+        self, __source: Mapping[str, "_T"], *, pre: str = "", suf: str = ""
+    ) -> _T:
+        """Equivalent to ``something[pre + "obs"/"var" + suf]``.
+
+        Args:
+            __source: The mapping to look the key up in.
+            pre: Text placed before the axis name when building the key.
+            suf: Text placed after the axis name when building the key.
+
+        Returns:
+            The value stored in ``__source`` under the constructed key.
+        """
+        return __source[pre + self.value + suf]
+
+
+@attrs.define
+class AxisIndexer(query.AxisIndexer):
+    """
+    Given a query, provides index-building services for obs/var axis.
+
+    The index for each axis is built the first time it is needed, from the
+    query's joinids for that axis, and is then reused by later calls.
+
+    Lifecycle: maturing
+    """
+
+    # The query whose obs/var joinids are indexed.
+    query: "ExperimentAxisQuery"
+    # Callable that builds an index from a NumPy array of joinids.
+    _index_factory: IndexFactory
+    # Lazily-built per-axis indexes; ``None`` until first requested.
+    _cached_obs: Optional[IndexLike] = None
+    _cached_var: Optional[IndexLike] = None
+
+    @property
+    def _obs_index(self) -> IndexLike:
+        """Private. Return an index for the ``obs`` axis."""
+        # Build once from the query's obs joinids, then serve from the cache.
+        if self._cached_obs is None:
+            self._cached_obs = self._index_factory(self.query.obs_joinids().to_numpy())
+        return self._cached_obs
+
+    @property
+    def _var_index(self) -> IndexLike:
+        """Private. Return an index for the ``var`` axis."""
+        # Build once from the query's var joinids, then serve from the cache.
+        if self._cached_var is None:
+            self._cached_var = self._index_factory(self.query.var_joinids().to_numpy())
+        return self._cached_var
+
+    def by_obs(self, coords: Numpyable) -> npt.NDArray[np.intp]:
+        """Reindex the coords (soma_joinids) over the ``obs`` axis.
+
+        Args:
+            coords: The ``soma_joinid`` values to look up; converted to a
+                NumPy array before the lookup.
+
+        Returns:
+            The result of the obs index's ``get_indexer`` for ``coords``.
+        """
+        return self._obs_index.get_indexer(_to_numpy(coords))
+
+    def by_var(self, coords: Numpyable) -> npt.NDArray[np.intp]:
+        """Reindex for the coords (soma_joinids) over the ``var`` axis.
+
+        Args:
+            coords: The ``soma_joinid`` values to look up; converted to a
+                NumPy array before the lookup.
+
+        Returns:
+            The result of the var index's ``get_indexer`` for ``coords``.
+        """
+        return self._var_index.get_indexer(_to_numpy(coords))
+
+
+def _to_numpy(it: Numpyable) -> npt.NDArray[np.int64]:
+    """Returns ``it`` as a NumPy array.
+
+    A NumPy array is returned as-is (the same object, no copy); any other
+    input is converted by calling its ``to_numpy()`` method.
+    """
+    if isinstance(it, np.ndarray):
+        return it
+    return it.to_numpy()  # type: ignore[no-any-return]
+
+
+@attrs.define(frozen=True)
+class AxisQueryResult:
+    """The result of running :meth:`ExperimentAxisQuery.read`. Private."""
+
+    obs: pd.DataFrame
+    """Experiment.obs query slice, as a pandas DataFrame"""
+    var: pd.DataFrame
+    """Experiment.ms[...].var query slice, as a pandas DataFrame"""
+    X: sp.csr_matrix
+    """Experiment.ms[...].X[...] query slice, as a SciPy sparse.csr_matrix """
+    X_layers: Dict[str, sp.csr_matrix] = attrs.field(factory=dict)
+    """Any additional X layers requested, as SciPy sparse.csr_matrix(s)"""
+    obsm: Dict[str, npt.NDArray[Any]] = attrs.field(factory=dict)
+    """Experiment.obsm query slice, as a numpy ndarray"""
+    obsp: Dict[str, npt.NDArray[Any]] = attrs.field(factory=dict)
+    """Experiment.obsp query slice, as a numpy ndarray"""
+    varm: Dict[str, npt.NDArray[Any]] = attrs.field(factory=dict)
+    """Experiment.varm query slice, as a numpy ndarray"""
+    varp: Dict[str, npt.NDArray[Any]] = attrs.field(factory=dict)
+    """Experiment.varp query slice, as a numpy ndarray"""
+
+    def to_anndata(self) -> AnnData:
+        """Builds an ``AnnData`` object from the fields of this result.
+
+        ``X``, ``obs`` and ``var`` are passed through unchanged. Each of the
+        dictionary-valued fields is passed as ``None`` when it is empty.
+        """
+        return AnnData(
+            X=self.X,
+            obs=self.obs,
+            var=self.var,
+            # ``x or None`` turns an empty dict into None.
+            obsm=(self.obsm or None),
+            obsp=(self.obsp or None),
+            varm=(self.varm or None),
+            varp=(self.varp or None),
+            layers=(self.X_layers or None),
+        )
+
+
+class ExperimentAxisQuery:
+    """Axis-based query against a SOMA Experiment.
+
+    ExperimentAxisQuery allows easy selection and extraction of data from a
+    single :class:`Measurement` in an :class:`Experiment`, by obs/var (axis) coordinates
+    and/or value filter.
+
+    The primary use for this class is slicing :class:`Experiment` ``X`` layers by obs or
+    var value and/or coordinates. Slicing on :class:`SparseNDArray` ``X`` matrices is
+    supported; :class:`DenseNDArray` is not supported at this time.
+
+    IMPORTANT: this class is not thread-safe.
+
+    IMPORTANT: this query class assumes it can store the full result of both
+    axis dataframe queries in memory, and only provides incremental access to
+    the underlying X NDArray. API features such as ``n_obs`` and ``n_vars``
+    codify this in the API.
+
+    IMPORTANT: you must call ``close()`` on any instance of this class to
+    release underlying resources. The ExperimentAxisQuery is a context manager,
+    and it is recommended that you use the following pattern to make this easy
+    and safe::
+
+        with ExperimentAxisQuery(...) as query:
+            ...
+
+    This base query implementation is designed to work against any SOMA
+    implementation that fulfills the basic APIs. A SOMA implementation may
+    include a custom query implementation optimized for its own use.
+
+    Lifecycle: maturing
+    """
+
+    def __init__(
+        self,
+        experiment: "Experiment",
+        measurement_name: str,
+        *,
+        obs_query: AxisQuery = AxisQuery(),
+        var_query: AxisQuery = AxisQuery(),
+        index_factory: IndexFactory = pd.Index,
+    ):
+        """Creates a query over one measurement of an experiment.
+
+        Args:
+            experiment: The experiment to query.
+            measurement_name: Name of the measurement in ``experiment.ms``
+                to query.
+            obs_query: Selection to apply on the ``obs`` axis. A falsy value
+                (such as ``None``) is replaced with a default ``AxisQuery()``.
+            var_query: Selection to apply on the ``var`` axis; handled the
+                same way as ``obs_query``.
+            index_factory: Callable used to build indexes over joinids.
+                Defaults to ``pd.Index``.
+
+        Raises:
+            ValueError: If ``measurement_name`` is not in ``experiment.ms``.
+        """
+        if measurement_name not in experiment.ms:
+            raise ValueError("Measurement does not exist in the experiment")
+
+        # Users often like to pass `foo=None` and we should let them
+        obs_query = obs_query or AxisQuery()
+        var_query = var_query or AxisQuery()
+
+        self.experiment = experiment
+        self.measurement_name = measurement_name
+
+        self._matrix_axis_query = MatrixAxisQuery(obs=obs_query, var=var_query)
+        self._joinids = JoinIDCache(self)
+        self._indexer = AxisIndexer(
+            self,
+            index_factory=index_factory,
+        )
+        self._index_factory = index_factory
+        # Private fallback pool; stays None until ``_threadpool`` creates it.
+        self._threadpool_: Optional[ThreadPoolExecutor] = None
+
+    def obs(
+        self,
+        *,
+        column_names: Optional[Sequence[str]] = None,
+        batch_size: BatchSize = BatchSize(),
+        partitions: Optional[ReadPartitions] = None,
+        result_order: ResultOrderStr = _RO_AUTO,
+        platform_config: Optional[PlatformConfig] = None,
+    ) -> ReadIter[pa.Table]:
+        """Returns ``obs`` as an `Arrow table
+        <https://arrow.apache.org/docs/python/generated/pyarrow.Table.html>`_
+        iterator.
+
+        The query's ``obs`` coordinates and value filter are applied to the
+        read; all keyword arguments are forwarded to the ``obs`` dataframe's
+        ``read`` method.
+
+        Lifecycle: maturing
+        """
+        obs_query = self._matrix_axis_query.obs
+        return self._obs_df.read(
+            obs_query.coords,
+            value_filter=obs_query.value_filter,
+            column_names=column_names,
+            batch_size=batch_size,
+            partitions=partitions,
+            result_order=result_order,
+            platform_config=platform_config,
+        )
+
+    def var(
+        self,
+        *,
+        column_names: Optional[Sequence[str]] = None,
+        batch_size: BatchSize = BatchSize(),
+        partitions: Optional[ReadPartitions] = None,
+        result_order: ResultOrderStr = _RO_AUTO,
+        platform_config: Optional[PlatformConfig] = None,
+    ) -> ReadIter[pa.Table]:
+        """Returns ``var`` as an `Arrow table
+        <https://arrow.apache.org/docs/python/generated/pyarrow.Table.html>`_
+        iterator.
+
+        The query's ``var`` coordinates and value filter are applied to the
+        read; all keyword arguments are forwarded to the ``var`` dataframe's
+        ``read`` method.
+
+        Lifecycle: maturing
+        """
+        var_query = self._matrix_axis_query.var
+        return self._var_df.read(
+            var_query.coords,
+            value_filter=var_query.value_filter,
+            column_names=column_names,
+            batch_size=batch_size,
+            partitions=partitions,
+            result_order=result_order,
+            platform_config=platform_config,
+        )
+
+    def obs_joinids(self) -> pa.IntegerArray:
+        """Returns ``obs`` ``soma_joinids`` as an Arrow array.
+
+        The values come from this query's joinid cache.
+
+        Lifecycle: maturing
+        """
+        return self._joinids.obs
+
+    def var_joinids(self) -> pa.IntegerArray:
+        """Returns ``var`` ``soma_joinids`` as an Arrow array.
+
+        The values come from this query's joinid cache.
+
+        Lifecycle: maturing
+        """
+        return self._joinids.var
+
+    @property
+    def n_obs(self) -> int:
+        """The number of ``obs`` axis query results.
+
+        Computed as the length of :meth:`obs_joinids`.
+
+        Lifecycle: maturing
+        """
+        return len(self.obs_joinids())
+
+    @property
+    def n_vars(self) -> int:
+        """The number of ``var`` axis query results.
+
+        Computed as the length of :meth:`var_joinids`.
+
+        Lifecycle: maturing
+        """
+        return len(self.var_joinids())
+
+    @property
+    def indexer(self) -> AxisIndexer:
+        """A ``soma_joinid`` indexer for both ``obs`` and ``var`` axes.
+
+        The same indexer instance, created when the query was constructed,
+        is returned on every access.
+
+        Lifecycle: maturing
+        """
+        return self._indexer
+
+    def X(
+        self,
+        layer_name: str,
+        *,
+        batch_size: BatchSize = BatchSize(),
+        partitions: Optional[ReadPartitions] = None,
+        result_order: ResultOrderStr = _RO_AUTO,
+        platform_config: Optional[PlatformConfig] = None,
+    ) -> SparseRead:
+        """Returns an ``X`` layer as a sparse read.
+
+        Args:
+            layer_name: The X layer name to return.
+            batch_size: The size of batches that should be returned from a read.
+                See :class:`BatchSize` for details.
+            partitions: Specifies that this is part of a partitioned read,
+                and which partition to include, if present.
+            result_order: the order to return results, specified as a
+                :class:`~ResultOrder` or its string value.
+            platform_config: Forwarded unchanged to the layer's ``read``.
+
+        Raises:
+            KeyError: If ``layer_name`` is not present in ``X``.
+            TypeError: If the layer is not a :class:`SparseNDArray`.
+
+        Lifecycle: maturing
+        """
+        try:
+            x_layer = self._ms.X[layer_name]
+        except KeyError as ke:
+            raise KeyError(f"{layer_name} is not present in X") from ke
+        if not isinstance(x_layer, SparseNDArray):
+            raise TypeError("X layers may only be sparse arrays")
+
+        # Load both axes' joinids on the thread pool before they are used
+        # as the read coordinates below.
+        self._joinids.preload(self._threadpool)
+        return x_layer.read(
+            (self._joinids.obs, self._joinids.var),
+            batch_size=batch_size,
+            partitions=partitions,
+            result_order=result_order,
+            platform_config=platform_config,
+        )
+
+    def obsp(self, layer: str) -> SparseRead:
+        """Returns an ``obsp`` layer as a sparse read.
+
+        Args:
+            layer: Name of the ``obsp`` layer to read.
+
+        Raises:
+            ValueError: If the measurement has no ``obsp`` data or ``layer``
+                is not in it.
+            TypeError: If the layer is not a :class:`SparseNDArray`.
+
+        Lifecycle: maturing
+        """
+        return self._axisp_inner(Axis.OBS, layer)
+
+    def varp(self, layer: str) -> SparseRead:
+        """Returns a ``varp`` layer as a sparse read.
+
+        Args:
+            layer: Name of the ``varp`` layer to read.
+
+        Raises:
+            ValueError: If the measurement has no ``varp`` data or ``layer``
+                is not in it.
+            TypeError: If the layer is not a :class:`SparseNDArray`.
+
+        Lifecycle: maturing
+        """
+        return self._axisp_inner(Axis.VAR, layer)
+
+    def obsm(self, layer: str) -> SparseRead:
+        """Returns an ``obsm`` layer as a sparse read.
+
+        Args:
+            layer: Name of the ``obsm`` layer to read.
+
+        Raises:
+            ValueError: If the measurement has no ``obsm`` data or ``layer``
+                is not in it.
+            TypeError: If the layer is not a :class:`SparseNDArray`.
+
+        Lifecycle: maturing
+        """
+        return self._axism_inner(Axis.OBS, layer)
+
+    def varm(self, layer: str) -> SparseRead:
+        """Returns a ``varm`` layer as a sparse read.
+
+        Args:
+            layer: Name of the ``varm`` layer to read.
+
+        Raises:
+            ValueError: If the measurement has no ``varm`` data or ``layer``
+                is not in it.
+            TypeError: If the layer is not a :class:`SparseNDArray`.
+
+        Lifecycle: maturing
+        """
+        return self._axism_inner(Axis.VAR, layer)
+
+    def obs_scene_ids(self) -> pa.Array:
+        """Returns a pyarrow array with scene ids that contain obs from this
+        query.
+
+        Raises:
+            KeyError: If the experiment's ``obs_spatial_presence`` lookup
+                raises ``KeyError``.
+            TypeError: If ``obs_spatial_presence`` is not a dataframe.
+
+        Lifecycle: experimental
+        """
+        try:
+            obs_scene = self.experiment.obs_spatial_presence
+        except KeyError as ke:
+            raise KeyError("Missing obs_scene") from ke
+        if not isinstance(obs_scene, DataFrame):
+            raise TypeError("obs_scene must be a dataframe.")
+
+        # Restrict the first coordinate to this query's obs joinids, take
+        # every value of the second, and keep only rows with ``data != 0``.
+        full_table = obs_scene.read(
+            coords=((Axis.OBS.getattr_from(self._joinids), slice(None))),
+            result_order=ResultOrder.COLUMN_MAJOR,
+            value_filter="data != 0",
+        ).concat()
+
+        # Each scene id is reported once, however many rows matched.
+        return pacomp.unique(full_table["scene_id"])
+
+    def var_scene_ids(self) -> pa.Array:
+        """Return a pyarrow array with scene ids that contain var from this
+        query.
+
+        Raises:
+            KeyError: If the measurement's ``var_spatial_presence`` lookup
+                raises ``KeyError``.
+            TypeError: If ``var_spatial_presence`` is not a dataframe.
+
+        Lifecycle: experimental
+        """
+        try:
+            var_scene = self._ms.var_spatial_presence
+        except KeyError as ke:
+            raise KeyError("Missing var_scene") from ke
+        if not isinstance(var_scene, DataFrame):
+            raise TypeError("var_scene must be a dataframe.")
+
+        # The var presence dataframe must be restricted by this query's
+        # *var* joinids (not the obs joinids), with every value of the second
+        # coordinate, keeping only rows with ``data != 0``.
+        full_table = var_scene.read(
+            coords=((Axis.VAR.getattr_from(self._joinids), slice(None))),
+            result_order=ResultOrder.COLUMN_MAJOR,
+            value_filter="data != 0",
+        ).concat()
+
+        # Each scene id is reported once, however many rows matched.
+        return pacomp.unique(full_table["scene_id"])
+
+    def to_anndata(
+        self,
+        X_name: str,
+        *,
+        column_names: Optional[AxisColumnNames] = None,
+        X_layers: Sequence[str] = (),
+        obsm_layers: Sequence[str] = (),
+        obsp_layers: Sequence[str] = (),
+        varm_layers: Sequence[str] = (),
+        varp_layers: Sequence[str] = (),
+        drop_levels: bool = False,
+    ) -> AnnData:
+        """Reads the whole query result and returns it as an ``AnnData``.
+
+        Args:
+            X_name: The X layer to read and return in the ``X`` slot.
+            column_names: The columns of the ``obs`` and ``var`` dataframes
+                to read. When omitted, no column restriction is applied to
+                either axis.
+            X_layers: Additional X layers to return in the ``layers`` slot.
+            obsm_layers: ``obsm`` layers to return in the ``obsm`` slot.
+            obsp_layers: ``obsp`` layers to return in the ``obsp`` slot.
+            varm_layers: ``varm`` layers to return in the ``varm`` slot.
+            varp_layers: ``varp`` layers to return in the ``varp`` slot.
+            drop_levels: If true, unused categories are removed from every
+                categorical column of ``obs`` and ``var``.
+        """
+        ad = self._read(
+            X_name,
+            column_names=column_names or AxisColumnNames(obs=None, var=None),
+            X_layers=X_layers,
+            obsm_layers=obsm_layers,
+            obsp_layers=obsp_layers,
+            varm_layers=varm_layers,
+            varp_layers=varp_layers,
+        ).to_anndata()
+
+        # Drop unused categories on axis dataframes if requested
+        if drop_levels:
+            for name in ad.obs:
+                if ad.obs[name].dtype.name == "category":
+                    ad.obs[name] = ad.obs[name].cat.remove_unused_categories()
+            for name in ad.var:
+                if ad.var[name].dtype.name == "category":
+                    ad.var[name] = ad.var[name].cat.remove_unused_categories()
+
+        return ad
+
+    # Context management
+
+    def close(self) -> None:
+        """Releases resources associated with this query.
+
+        Shuts down the thread pool stored in ``_threadpool_``, if there is
+        one, and clears the attribute so that a repeated call is a no-op.
+
+        This method must be idempotent.
+
+        Lifecycle: maturing
+        """
+        # Because this may be called during ``__del__`` when we might be getting
+        # disassembled, sometimes ``_threadpool_`` is simply missing.
+        # Only try to shut it down if it still exists.
+        pool = getattr(self, "_threadpool_", None)
+        if pool is None:
+            return
+        pool.shutdown()
+        self._threadpool_ = None
+
+    def __enter__(self) -> Self:
+        """Enters the context manager, returning this query itself."""
+        return self
+
+    def __exit__(self, *_: Any) -> None:
+        """Closes the query on leaving the ``with`` block.
+
+        The exception details are ignored and nothing is returned, so an
+        exception raised inside the block is not suppressed.
+        """
+        self.close()
+
+    def __del__(self) -> None:
+        """Ensure that we're closed when our last ref disappears."""
+        self.close()
+        # If any superclass in our MRO has a __del__, call it.
+        # The no-op lambda is the fallback when none of them defines one.
+        sdel = getattr(super(), "__del__", lambda: None)
+        sdel()
+
+    # Internals
+
+    def _read(
+        self,
+        X_name: str,
+        *,
+        column_names: AxisColumnNames,
+        X_layers: Sequence[str],
+        obsm_layers: Sequence[str] = (),
+        obsp_layers: Sequence[str] = (),
+        varm_layers: Sequence[str] = (),
+        varp_layers: Sequence[str] = (),
+    ) -> AxisQueryResult:
+        """Reads the entire query result in memory.
+
+        This is a low-level routine intended to be used by loaders for other
+        in-core formats, such as AnnData, which can be created from the
+        resulting objects.
+
+        Args:
+            X_name: The X layer to read and return in the ``X`` slot.
+            column_names: The columns in the ``var`` and ``obs`` dataframes
+                to read.
+            X_layers: Additional X layers to read and return
+                in the ``layers`` slot.
+            obsm_layers:
+                Additional obsm layers to read and return in the obsm slot.
+            obsp_layers:
+                Additional obsp layers to read and return in the obsp slot.
+            varm_layers:
+                Additional varm layers to read and return in the varm slot.
+            varp_layers:
+                Additional varp layers to read and return in the varp slot.
+
+        Raises:
+            ValueError: If an X layer name is not a non-empty string, or is
+                not present in the measurement's ``X`` collection.
+            NotImplementedError: If a requested X layer is not a
+                :class:`SparseNDArray`.
+        """
+        # Validate every requested X layer up front, before any read starts.
+        x_collection = self._ms.X
+        all_x_names = [X_name] + list(X_layers)
+        all_x_arrays: Dict[str, SparseNDArray] = {}
+        for _xname in all_x_names:
+            if not isinstance(_xname, str) or not _xname:
+                raise ValueError("X layer names must be specified as a string.")
+            if _xname not in x_collection:
+                raise ValueError("Unknown X layer name")
+            x_array = x_collection[_xname]
+            if not isinstance(x_array, SparseNDArray):
+                raise NotImplementedError("Dense array unsupported")
+            all_x_arrays[_xname] = x_array
+
+        def _read_axis_mappings(
+            fn: Callable[[Axis, str], npt.NDArray[Any]],
+            axis: Axis,
+            keys: Sequence[str],
+        ) -> Dict[str, npt.NDArray[Any]]:
+            # Read each named layer of one axis with ``fn``, keyed by name.
+            return {key: fn(axis, key) for key in keys}
+
+        # Submit the obsm/obsp/varm/varp reads to the thread pool; their
+        # results are collected when the result object is assembled below.
+        obsm_ft = self._threadpool.submit(
+            _read_axis_mappings, self._axism_inner_ndarray, Axis.OBS, obsm_layers
+        )
+        obsp_ft = self._threadpool.submit(
+            _read_axis_mappings, self._axisp_inner_ndarray, Axis.OBS, obsp_layers
+        )
+        varm_ft = self._threadpool.submit(
+            _read_axis_mappings, self._axism_inner_ndarray, Axis.VAR, varm_layers
+        )
+        varp_ft = self._threadpool.submit(
+            _read_axis_mappings, self._axisp_inner_ndarray, Axis.VAR, varp_layers
+        )
+
+        obs_table, var_table = self._read_both_axes(column_names)
+
+        obs_joinids = self.obs_joinids()
+        var_joinids = self.var_joinids()
+
+        x_matrices = {
+            _xname: (
+                _fast_csr.read_csr(
+                    layer,
+                    obs_joinids,
+                    var_joinids,
+                    index_factory=self._index_factory,
+                ).to_scipy()
+            )
+            for _xname, layer in all_x_arrays.items()
+        }
+
+        # The primary layer goes in ``X``; whatever remains becomes the
+        # additional ``X_layers``.
+        x = x_matrices.pop(X_name)
+
+        obs = obs_table.to_pandas()
+        obs.index = obs.index.astype(str)
+
+        var = var_table.to_pandas()
+        var.index = var.index.astype(str)
+
+        return AxisQueryResult(
+            obs=obs,
+            var=var,
+            X=x,
+            obsm=obsm_ft.result(),
+            obsp=obsp_ft.result(),
+            varm=varm_ft.result(),
+            varp=varp_ft.result(),
+            X_layers=x_matrices,
+        )
+
+    def _read_both_axes(
+        self,
+        column_names: AxisColumnNames,
+    ) -> Tuple[pa.Table, pa.Table]:
+        """Reads both axes in their entirety, ensuring soma_joinid is retained.
+
+        Both reads are submitted to the thread pool and then awaited.
+
+        Args:
+            column_names: The per-axis column names to read.
+
+        Returns:
+            The ``(obs, var)`` tables, in that order.
+        """
+        obs_ft = self._threadpool.submit(
+            self._read_axis_dataframe,
+            Axis.OBS,
+            column_names,
+        )
+        var_ft = self._threadpool.submit(
+            self._read_axis_dataframe,
+            Axis.VAR,
+            column_names,
+        )
+        return obs_ft.result(), var_ft.result()
+
+    def _read_axis_dataframe(
+        self,
+        axis: Axis,
+        axis_column_names: AxisColumnNames,
+    ) -> pa.Table:
+        """Reads the specified axis. Will cache join IDs if not present.
+
+        Args:
+            axis: The axis (obs or var) whose dataframe is read.
+            axis_column_names: The per-axis column names; only the entry for
+                ``axis`` is used.
+
+        Returns:
+            The axis dataframe's query result as a single Arrow table. If
+            ``soma_joinid`` was added to the read only to fill the joinid
+            cache, it is removed from the returned table.
+        """
+        column_names = axis_column_names.get(axis.value)
+
+        # Resolves to ``self._obs_df`` or ``self._var_df``.
+        axis_df = axis.getattr_from(self, pre="_", suf="_df")
+        assert isinstance(axis_df, DataFrame)
+        axis_query = axis.getattr_from(self._matrix_axis_query)
+
+        # If we can cache join IDs, prepare to add them to the cache.
+        joinids_cached = self._joinids._is_cached(axis)
+        query_columns = column_names
+        added_soma_joinid_to_columns = False
+        if (
+            not joinids_cached
+            and column_names is not None
+            and "soma_joinid" not in column_names
+        ):
+            # If we want to fill the join ID cache, ensure that we query the
+            # soma_joinid column so that it is included in the results.
+            # We'll filter it out later.
+            query_columns = ["soma_joinid"] + list(column_names)
+            added_soma_joinid_to_columns = True
+
+        # Do the actual query.
+        arrow_table = axis_df.read(
+            coords=axis_query.coords,
+            value_filter=axis_query.value_filter,
+            column_names=query_columns,
+        ).concat()
+
+        # Update the cache if needed. We can do this because no matter what
+        # other columns are queried for, the contents of the ``soma_joinid``
+        # column will be the same and can be safely stored.
+        if not joinids_cached:
+            setattr(
+                self._joinids,
+                axis.value,
+                arrow_table.column("soma_joinid").combine_chunks(),
+            )
+
+        # Drop soma_joinid column if we added it solely for use in filling
+        # the joinid cache.
+        if added_soma_joinid_to_columns:
+            arrow_table = arrow_table.drop(["soma_joinid"])
+        return arrow_table
+
+    def _axisp_inner(
+        self,
+        axis: Axis,
+        layer: str,
+    ) -> SparseRead:
+        """Starts a sparse read of one ``obsp``/``varp`` layer.
+
+        The same joinids (those of ``axis``) are used as the coordinates for
+        both dimensions of the read.
+
+        Args:
+            axis: Selects ``obsp`` (``Axis.OBS``) or ``varp`` (``Axis.VAR``).
+            layer: Name of the layer to read.
+
+        Raises:
+            ValueError: If the measurement has no such collection, or the
+                layer is not in it.
+            TypeError: If the layer is not a :class:`SparseNDArray`.
+        """
+        p_name = f"{axis.value}p"
+        try:
+            ms = self._ms
+            axisp = ms.obsp if axis.value == "obs" else ms.varp
+        except (AttributeError, KeyError):
+            raise ValueError(f"Measurement does not contain {p_name} data")
+
+        try:
+            ap_layer = axisp[layer]
+        except KeyError:
+            raise ValueError(f"layer {layer!r} is not available in {p_name}")
+        if not isinstance(ap_layer, SparseNDArray):
+            raise TypeError(
+                f"Unexpected SOMA type {type(ap_layer).__name__}"
+                f" stored in {p_name} layer {layer!r}"
+            )
+
+        joinids = axis.getattr_from(self._joinids)
+        return ap_layer.read((joinids, joinids))
+
+    def _axism_inner(
+        self,
+        axis: Axis,
+        layer: str,
+    ) -> SparseRead:
+        """Starts a sparse read of one ``obsm``/``varm`` layer.
+
+        The first dimension is restricted to the joinids of ``axis``; the
+        second dimension is read in full (``slice(None)``).
+
+        Args:
+            axis: Selects ``obsm`` (``Axis.OBS``) or ``varm`` (``Axis.VAR``).
+            layer: Name of the layer to read.
+
+        Raises:
+            ValueError: If the measurement has no such collection, or the
+                layer is not in it.
+            TypeError: If the layer is not a :class:`SparseNDArray`.
+        """
+        m_name = f"{axis.value}m"
+
+        try:
+            ms = self._ms
+            axism = ms.obsm if axis.value == "obs" else ms.varm
+        except (AttributeError, KeyError):
+            raise ValueError(f"Measurement does not contain {m_name} data")
+
+        try:
+            axism_layer = axism[layer]
+        except KeyError:
+            raise ValueError(f"layer {layer!r} is not available in {m_name}")
+
+        if not isinstance(axism_layer, SparseNDArray):
+            raise TypeError(f"Unexpected SOMA type stored in '{m_name}' layer")
+
+        joinids = axis.getattr_from(self._joinids)
+        return axism_layer.read((joinids, slice(None)))
+
+    def _convert_to_ndarray(
+        self, axis: Axis, table: pa.Table, n_row: int, n_col: int
+    ) -> npt.NDArray[np.float32]:
+        """Scatters a table of sparse entries into a dense float32 array.
+
+        Args:
+            axis: The axis whose indexer (``by_obs`` or ``by_var``) maps the
+                ``soma_dim_0`` values to row positions.
+            table: Table with ``soma_dim_0``, ``soma_dim_1`` and
+                ``soma_data`` columns.
+            n_row: Number of rows of the result.
+            n_col: Number of columns of the result.
+
+        Returns:
+            An ``(n_row, n_col)`` array, zero wherever the table has no entry.
+        """
+        # Resolves to ``self.indexer.by_obs`` or ``self.indexer.by_var``.
+        indexer = cast(
+            Callable[[Numpyable], npt.NDArray[np.intp]],
+            axis.getattr_from(self.indexer, pre="by_"),
+        )
+        idx = indexer(table["soma_dim_0"])
+        z: npt.NDArray[np.float32] = np.zeros(n_row * n_col, dtype=np.float32)
+        # ``row * n_col + col`` is the flat position of (row, col) in the
+        # array that is reshaped to (n_row, n_col) below.
+        # NOTE(review): ``soma_dim_1`` is used as the column position as-is,
+        # without being passed through an indexer -- confirm that this is
+        # intended for obsp/varp reads over a subset of joinids.
+        np.put(z, idx * n_col + table["soma_dim_1"], table["soma_data"])
+        return z.reshape(n_row, n_col)
+
+    def _axisp_inner_ndarray(
+        self,
+        axis: Axis,
+        layer: str,
+    ) -> npt.NDArray[np.float32]:
+        """Reads one ``obsp``/``varp`` layer into a dense square array.
+
+        Both dimensions of the result have one entry per joinid of ``axis``.
+        """
+        n_row = n_col = len(axis.getattr_from(self._joinids))
+
+        table = self._axisp_inner(axis, layer).tables().concat()
+        return self._convert_to_ndarray(axis, table, n_row, n_col)
+
+    def _axism_inner_ndarray(
+        self,
+        axis: Axis,
+        layer: str,
+    ) -> npt.NDArray[np.float32]:
+        """Reads one ``obsm``/``varm`` layer into a dense array.
+
+        The result has one row per joinid of ``axis``.
+        """
+        table = self._axism_inner(axis, layer).tables().concat()
+
+        n_row = len(axis.getattr_from(self._joinids))
+        # NOTE(review): the column count is the number of *distinct*
+        # ``soma_dim_1`` values present in the result, not the layer's
+        # declared width -- confirm this is correct when some column has no
+        # stored values for the selected rows.
+        n_col = len(table["soma_dim_1"].unique())
+
+        return self._convert_to_ndarray(axis, table, n_row, n_col)
+
+    @property
+    def _obs_df(self) -> DataFrame:
+        """The experiment's ``obs`` dataframe."""
+        return self.experiment.obs
+
+    @property
+    def _ms(self) -> Measurement:
+        """The measurement named ``measurement_name`` in the experiment."""
+        return self.experiment.ms[self.measurement_name]
+
+    @property
+    def _var_df(self) -> DataFrame:
+        """The ``var`` dataframe of the queried measurement."""
+        return self._ms.var
+
+    @property
+    def _threadpool(self) -> ThreadPoolExecutor:
+        """
+        Returns the threadpool provided by the experiment's context.
+
+        If not available, creates a thread pool just in time and keeps it in
+        ``_threadpool_`` so that later accesses reuse it.
+        """
+        context = self.experiment.context
+        if context and context.threadpool:
+            return context.threadpool
+
+        if self._threadpool_ is None:
+            self._threadpool_ = ThreadPoolExecutor()
+        return self._threadpool_
+
+
+@attrs.define(frozen=True)
+class MatrixAxisQuery:
+    """The per-axis user query definition. Private."""
+
+    # The query applied to the ``obs`` axis.
+    obs: AxisQuery
+    # The query applied to the ``var`` axis.
+    var: AxisQuery
+
+
+@attrs.define
+class JoinIDCache:
+    """A cache for per-axis join ids in the query. Private.
+
+    An axis counts as cached once its slot is not ``None``. An empty array
+    is a valid cached value (a query that matched nothing), so the slots are
+    always tested with ``is None`` rather than by truthiness.
+    """
+
+    # The query whose axis dataframes and axis queries are used for loading.
+    owner: ExperimentAxisQuery
+
+    # Per-axis cached joinids; ``None`` means "not loaded yet".
+    _cached_obs: Optional[pa.IntegerArray] = None
+    _cached_var: Optional[pa.IntegerArray] = None
+
+    def _is_cached(self, axis: Axis) -> bool:
+        """Returns whether the joinids of ``axis`` have been loaded."""
+        field = "_cached_" + axis.value
+        return getattr(self, field) is not None
+
+    def preload(self, pool: ThreadPoolExecutor) -> None:
+        """Loads the joinids of both axes on ``pool`` and waits for them.
+
+        Does nothing when both axes are already cached. An exception raised
+        by either load is re-raised here.
+        """
+        if self._cached_obs is not None and self._cached_var is not None:
+            return
+        obs_ft = pool.submit(lambda: self.obs)
+        var_ft = pool.submit(lambda: self.var)
+        # Wait for them and raise in case of error.
+        obs_ft.result()
+        var_ft.result()
+
+    @property
+    def obs(self) -> pa.IntegerArray:
+        """Join IDs for the obs axis. Will load and cache if not already."""
+        # ``is None`` (not truthiness): an empty result must stay cached.
+        if self._cached_obs is None:
+            self._cached_obs = load_joinids(
+                self.owner._obs_df, self.owner._matrix_axis_query.obs
+            )
+        return self._cached_obs
+
+    @obs.setter
+    def obs(self, val: pa.IntegerArray) -> None:
+        self._cached_obs = val
+
+    @property
+    def var(self) -> pa.IntegerArray:
+        """Join IDs for the var axis. Will load and cache if not already."""
+        # ``is None`` (not truthiness): an empty result must stay cached.
+        if self._cached_var is None:
+            self._cached_var = load_joinids(
+                self.owner._var_df, self.owner._matrix_axis_query.var
+            )
+        return self._cached_var
+
+    @var.setter
+    def var(self, val: pa.IntegerArray) -> None:
+        self._cached_var = val
+
+
+def load_joinids(df: DataFrame, axq: AxisQuery) -> pa.IntegerArray:
+    """Reads the ``soma_joinid`` values of ``df`` selected by ``axq``.
+
+    Args:
+        df: The axis dataframe to read from.
+        axq: The axis query supplying the coordinates and value filter.
+
+    Returns:
+        The ``soma_joinid`` column of the result, combined into one array.
+    """
+    # Only the ``soma_joinid`` column is requested from the dataframe.
+    tbl = df.read(
+        axq.coords,
+        value_filter=axq.value_filter,
+        column_names=["soma_joinid"],
+    ).concat()
+    return tbl.column("soma_joinid").combine_chunks()
diff --git a/apis/python/tests/test_experiment_query.py b/apis/python/tests/test_experiment_query.py
index 1831e5e59c..401831f29d 100644
--- a/apis/python/tests/test_experiment_query.py
+++ b/apis/python/tests/test_experiment_query.py
@@ -4,6 +4,7 @@
 from typing import Tuple
 from unittest import mock
 
+import attrs
 import numpy as np
 import pandas as pd
 import pyarrow as pa
@@ -13,8 +14,10 @@
 from somacore import AxisQuery, options
 
 import tiledbsoma as soma
-from tiledbsoma import SOMATileDBContext, _factory
+from tiledbsoma import SOMATileDBContext, _factory, pytiledbsoma
 from tiledbsoma._collection import CollectionBase
+from tiledbsoma._experiment import Experiment
+from tiledbsoma._query import Axis, ExperimentAxisQuery
 from tiledbsoma.experiment_query import X_as_series
 
 from tests._util import raises_no_typeguard
@@ -58,7 +61,7 @@ def soma_experiment(
     varp_layer_names,
     obsm_layer_names,
     varm_layer_names,
-):
+) -> Experiment:
     with soma.Experiment.create((tmp_path / "exp").as_posix()) as exp:
         add_dataframe(exp, "obs", n_obs)
         ms = exp.add_new_collection("ms")
@@ -88,7 +91,7 @@ def soma_experiment(
             for varm_layer_name in varm_layer_names:
                 add_sparse_array(varm, varm_layer_name, (n_vars, N_FEATURES))
 
-    return _factory.open((tmp_path / "exp").as_posix())
+    return Experiment.open((tmp_path / "exp").as_posix())
 
 
 def get_soma_experiment_with_context(soma_experiment, context):
@@ -99,7 +102,7 @@ def get_soma_experiment_with_context(soma_experiment, context):
 @pytest.mark.parametrize("n_obs,n_vars,X_layer_names", [(101, 11, ("raw", "extra"))])
 def test_experiment_query_all(soma_experiment):
     """Test a query with default obs_query / var_query -- i.e., query all."""
-    with soma.ExperimentAxisQuery(soma_experiment, "RNA") as query:
+    with ExperimentAxisQuery(soma_experiment, "RNA") as query:
         assert query.n_obs == 101
         assert query.n_vars == 11
 
@@ -304,7 +307,7 @@ def test_experiment_query_batch_size(soma_experiment):
     This test merely verifies that the batch_size parameter is accepted
     but as a no-op.
     """
-    with soma.ExperimentAxisQuery(soma_experiment, "RNA") as query:
+    with ExperimentAxisQuery(soma_experiment, "RNA") as query:
         tbls = query.obs(batch_size=options.BatchSize(count=100))
         assert len(list(tbls)) == 1  # batch_size currently not implemented
 
@@ -315,7 +318,7 @@ def test_experiment_query_partitions(soma_experiment):
     partitions is currently not supported by this implementation of SOMA.
     This test checks if a ValueError is raised if a partitioning is requested.
     """
-    with soma.ExperimentAxisQuery(soma_experiment, "RNA") as query:
+    with ExperimentAxisQuery(soma_experiment, "RNA") as query:
         with pytest.raises(ValueError):
             query.obs(partitions=options.IOfN(i=0, n=3)).concat()
 
@@ -328,7 +331,7 @@ def test_experiment_query_partitions(soma_experiment):
 
 @pytest.mark.parametrize("n_obs,n_vars", [(10, 10)])
 def test_experiment_query_result_order(soma_experiment):
-    with soma.ExperimentAxisQuery(soma_experiment, "RNA") as query:
+    with ExperimentAxisQuery(soma_experiment, "RNA") as query:
         # Since obs is 1-dimensional, row-major and column-major should be the same
         obs_data_row_major = (
             query.obs(result_order="row-major").concat()["label"].to_numpy()
@@ -393,11 +396,10 @@ def test_experiment_axis_query_with_none(soma_experiment):
     """Test query by value filter"""
     obs_label_values = ["3", "7", "38", "99"]
 
-    with soma.ExperimentAxisQuery(
+    with ExperimentAxisQuery(
         experiment=soma_experiment,
         measurement_name="RNA",
         obs_query=soma.AxisQuery(value_filter=f"label in {obs_label_values}"),
-        var_query=None,
     ) as query:
         assert query.n_obs == len(obs_label_values)
         assert query.obs().concat()["label"].to_pylist() == obs_label_values
@@ -462,7 +464,7 @@ def test_X_layers(soma_experiment):
 @pytest.mark.parametrize("n_obs,n_vars", [(1001, 99)])
 def test_experiment_query_indexer(soma_experiment):
     """Test result indexer"""
-    with soma.ExperimentAxisQuery(
+    with ExperimentAxisQuery(
         soma_experiment,
         "RNA",
         obs_query=soma.AxisQuery(coords=(slice(1, 10),)),
@@ -501,7 +503,7 @@ def test_experiment_query_indexer(soma_experiment):
 
 
 @pytest.mark.parametrize("n_obs,n_vars", [(2833, 107)])
-def test_error_corners(soma_experiment: soma.Experiment):
+def test_error_corners(soma_experiment: Experiment):
     """Verify a couple of error conditions / corner cases."""
     # Unknown Measurement name
     with pytest.raises(ValueError):
@@ -532,16 +534,16 @@ def test_error_corners(soma_experiment: soma.Experiment):
         with soma_experiment.axis_query("RNA") as query:
             with raises_no_typeguard(KeyError):
                 next(query.X(lyr_name))
-            with pytest.raises(ValueError):
+            with raises_no_typeguard(ValueError):
                 next(query.obsp(lyr_name))
-            with pytest.raises(ValueError):
+            with raises_no_typeguard(ValueError):
                 next(query.varp(lyr_name))
 
 
 @pytest.mark.parametrize("n_obs,n_vars", [(1001, 99)])
-def test_query_cleanup(soma_experiment: soma.Experiment):
+def test_query_cleanup(soma_experiment: Experiment):
     """
-    Verify soma.Experiment.query works as context manager and stand-alone,
+    Verify Experiment.query works as context manager and stand-alone,
     and that it cleans up correctly.
     """
     from contextlib import closing
@@ -574,16 +576,16 @@ def test_query_cleanup(soma_experiment: soma.Experiment):
 def test_experiment_query_obsp_varp_obsm_varm(soma_experiment):
     obs_slice = slice(3, 72)
     var_slice = slice(7, 21)
-    with soma.ExperimentAxisQuery(
+    with ExperimentAxisQuery(
         soma_experiment,
         "RNA",
-        obs_query=soma.AxisQuery(coords=(obs_slice,)),
-        var_query=soma.AxisQuery(coords=(var_slice,)),
+        obs_query=AxisQuery(coords=(obs_slice,)),
+        var_query=AxisQuery(coords=(var_slice,)),
     ) as query:
         assert query.n_obs == obs_slice.stop - obs_slice.start + 1
         assert query.n_vars == var_slice.stop - var_slice.start + 1
 
-        with pytest.raises(ValueError):
+        with raises_no_typeguard(ValueError):
             next(query.obsp("no-such-layer"))
 
         with pytest.raises(ValueError):
@@ -682,44 +684,40 @@ def test_experiment_query_to_anndata_obsp_varp(soma_experiment):
 
 def test_axis_query():
     """Basic test of the AxisQuery class"""
-    assert soma.AxisQuery().coords == ()
-    assert soma.AxisQuery().value_filter is None
-    assert soma.AxisQuery() == soma.AxisQuery(coords=())
+    assert AxisQuery().coords == ()
+    assert AxisQuery().value_filter is None
+    assert AxisQuery() == AxisQuery(coords=())
 
-    assert soma.AxisQuery(coords=(1,)).coords == (1,)
-    assert soma.AxisQuery(coords=(slice(1, 2),)).coords == (slice(1, 2),)
-    assert soma.AxisQuery(coords=((1, 88),)).coords == ((1, 88),)
+    assert AxisQuery(coords=(1,)).coords == (1,)
+    assert AxisQuery(coords=(slice(1, 2),)).coords == (slice(1, 2),)
+    assert AxisQuery(coords=((1, 88),)).coords == ((1, 88),)
 
-    assert soma.AxisQuery(coords=(1, 2)).coords == (1, 2)
-    assert soma.AxisQuery(coords=(slice(1, 2), slice(None))).coords == (
+    assert AxisQuery(coords=(1, 2)).coords == (1, 2)
+    assert AxisQuery(coords=(slice(1, 2), slice(None))).coords == (
         slice(1, 2),
         slice(None),
     )
-    assert soma.AxisQuery(coords=(slice(1, 2),)).value_filter is None
+    assert AxisQuery(coords=(slice(1, 2),)).value_filter is None
 
-    assert soma.AxisQuery(value_filter="foo == 'bar'").value_filter == "foo == 'bar'"
-    assert soma.AxisQuery(value_filter="foo == 'bar'").coords == ()
+    assert AxisQuery(value_filter="foo == 'bar'").value_filter == "foo == 'bar'"
+    assert AxisQuery(value_filter="foo == 'bar'").coords == ()
 
-    assert soma.AxisQuery(
-        coords=(slice(1, 100),), value_filter="foo == 'bar'"
-    ).coords == (
+    assert AxisQuery(coords=(slice(1, 100),), value_filter="foo == 'bar'").coords == (
         slice(1, 100),
     )
     assert (
-        soma.AxisQuery(
-            coords=(slice(1, 100),), value_filter="foo == 'bar'"
-        ).value_filter
+        AxisQuery(coords=(slice(1, 100),), value_filter="foo == 'bar'").value_filter
         == "foo == 'bar'"
     )
 
     with pytest.raises(TypeError):
-        soma.AxisQuery(coords=True)
+        AxisQuery(coords=True)
 
     with pytest.raises(TypeError):
-        soma.AxisQuery(value_filter=[])
+        AxisQuery(value_filter=[])
 
     with pytest.raises(TypeError):
-        soma.AxisQuery(coords=({},))
+        AxisQuery(coords=({},))
 
 
 def test_X_as_series():
@@ -801,10 +799,10 @@ def test_experiment_query_column_names(soma_experiment):
     # column_names and value_filter
     with soma_experiment.axis_query(
         "RNA",
-        obs_query=soma.AxisQuery(
+        obs_query=AxisQuery(
             value_filter="label in [" + ",".join(f"'{i}'" for i in range(101)) + "]"
         ),
-        var_query=soma.AxisQuery(
+        var_query=AxisQuery(
             value_filter="label in [" + ",".join(f"'{i}'" for i in range(99)) + "]"
         ),
     ) as query:
@@ -860,7 +858,7 @@ def test_experiment_query_mp_disjoint_arrow_coords(soma_experiment):
     for ids in slices:
         with soma_experiment.axis_query(
             "RNA",
-            obs_query=soma.AxisQuery(coords=(ids,)),
+            obs_query=AxisQuery(coords=(ids,)),
         ) as query:
             assert query.obs_joinids() == ids
 
@@ -951,7 +949,7 @@ def test_empty_categorical_query(conftest_pbmc_small_exp):
         measurement_name="RNA", obs_query=AxisQuery(value_filter='groups == "foo"')
     )
     # Empty query on a categorical column raised ArrowInvalid before TileDB 2.21; see https://github.com/single-cell-data/TileDB-SOMA/pull/2299
-    m = re.fullmatch(r"libtiledb=(\d+\.\d+\.\d+)", soma.pytiledbsoma.version())
+    m = re.fullmatch(r"libtiledb=(\d+\.\d+\.\d+)", pytiledbsoma.version())
     version = m.group(1).split(".")
     major, minor = int(version[0]), int(version[1])
 
@@ -959,3 +957,24 @@ def test_empty_categorical_query(conftest_pbmc_small_exp):
     with ctx:
         obs = q.obs().concat()
         assert len(obs) == 0
+
+
+@attrs.define(frozen=True)
+class IHaveObsVarStuff:  # Immutable fixture for the Axis attribute-lookup test.
+    obs: int
+    var: int
+    the_obs_suf: str  # Target of Axis.OBS lookup with pre="the_", suf="_suf".
+    the_var_suf: str  # Target of Axis.VAR lookup with pre="the_", suf="_suf".
+
+
+def test_axis_helpers() -> None:
+    """Check Axis.getattr_from/getitem_from with and without pre/suf affixes."""
+    holder = IHaveObsVarStuff(obs=1, var=2, the_obs_suf="observe", the_var_suf="vary")
+    assert Axis.OBS.getattr_from(holder) == 1
+    assert Axis.VAR.getattr_from(holder) == 2
+    assert Axis.OBS.getattr_from(holder, pre="the_", suf="_suf") == "observe"
+    assert Axis.VAR.getattr_from(holder, pre="the_", suf="_suf") == "vary"
+    lookup = {"obs": "erve", "var": "y", "i_obscure": "hide", "i_varcure": "???"}
+    for axis, plain, affixed in ((Axis.OBS, "erve", "hide"), (Axis.VAR, "y", "???")):
+        assert axis.getitem_from(lookup) == plain
+        assert axis.getitem_from(lookup, pre="i_", suf="cure") == affixed