Commit
[python] Partition sparse matrix reads in tiledbsoma.io.to_anndata (#3328)

* first cut at fast CSX conversion

* fix compile time warning

* build python bindings -O3

* typos

* fix clang compile issue

* reformat to repo C++ standards

* fix clang warning about unused captures

* more clang fixes

* fix build error found during CI

* lint

* more lint

* lint chase

* cleanup include statements

* debugging R build

* lint

* cleanup GHA for interop testing

* add -mavx2 for x86 build

* more tests

* comment

* add bounds check for second dimension coordinate

* lint

* test / bug fix argument handling

* hand code AVX specializations for speed eval

* reduce template overhead

* do not use experimental value width dispatching

* lint

* improve exception discrimination

* cleanup size-based value templating

* cleanup

* partition sparse array reads

* clean up C++ namespace

* revert cleanup on request

* PR feedback (thanks John!)

* incorporate more PR feedback

* lint

* comments

* PR feedback

* fix compile warnings

* PR feedback
bkmartinjr authored Nov 19, 2024
1 parent d4e5a42 commit 3808ed9
Showing 1 changed file with 63 additions and 27 deletions.
apis/python/src/tiledbsoma/io/outgest.py (63 additions, 27 deletions)
@@ -25,6 +25,7 @@
 import anndata as ad
 import numpy as np
 import pandas as pd
+import pyarrow as pa
 import scipy.sparse as sp
 
 from .. import (
@@ -101,36 +102,66 @@ def to_h5ad(
     )
 
 
+# ----------------------------------------------------------------
+def _read_partitioned_sparse(X: SparseNDArray, d0_size: int) -> pa.Table:
+    # Partition dimension 0 based on a target point count and average
+    # density of matrix. Magic number determined empirically, as a tradeoff
+    # between concurrency and fixed query overhead.
+    tgt_point_count = 92 * 1024**2
+    partition_sz = (
+        max(1024 * round(d0_size * tgt_point_count / X.nnz / 1024), 1024)
+        if X.nnz > 0
+        else d0_size
+    )
+    partitions = [
+        slice(st, min(st + partition_sz - 1, d0_size - 1))
+        for st in range(0, d0_size, partition_sz)
+    ]
+    n_partitions = len(partitions)
+
+    def _read_sparse_X(A: SparseNDArray, row_slc: slice) -> pa.Table:
+        return A.read(coords=(row_slc,)).tables().concat()
+
+    if n_partitions > 1:  # don't consume threads unless there is a reason to do so
+        return pa.concat_tables(
+            X.context.threadpool.map(_read_sparse_X, (X,) * n_partitions, partitions)
+        )
+    else:
+        return _read_sparse_X(X, partitions[0])
+
+
 def _extract_X_key(
-    measurement: Measurement,
-    X_layer_name: str,
-    nobs: int,
-    nvar: int,
-) -> Matrix:
+    measurement: Measurement, X_layer_name: str, nobs: int, nvar: int
+) -> Future[Matrix]:
     """Helper function for to_anndata"""
     if X_layer_name not in measurement.X:
         raise ValueError(
             f"X_layer_name {X_layer_name} not found in data: {measurement.X.keys()}"
         )
 
     # Acquire handle to TileDB-SOMA data
-    soma_X_data_handle = measurement.X[X_layer_name]
+    X = measurement.X[X_layer_name]
+    tp = X.context.threadpool
 
     # Read data from SOMA into memory
-    if isinstance(soma_X_data_handle, DenseNDArray):
-        data = soma_X_data_handle.read((slice(None), slice(None))).to_numpy()
-    elif isinstance(soma_X_data_handle, SparseNDArray):
-        data = conversions.csr_from_coo_table(
-            soma_X_data_handle.read().tables().concat(),
-            nobs,
-            nvar,
-            context=soma_X_data_handle.context,
-        )
-    else:
-        raise TypeError(f"Unexpected NDArray type {type(soma_X_data_handle)}")
+    if isinstance(X, DenseNDArray):
+
+        def _read_dense_X(A: DenseNDArray) -> Matrix:
+            return A.read((slice(None), slice(None))).to_numpy()
+
+        return tp.submit(_read_dense_X, X)
 
-    return data
+    elif isinstance(X, SparseNDArray):
+
+        def _read_X_partitions() -> Matrix:
+            stk_of_coo = _read_partitioned_sparse(X, nobs)
+            return conversions.csr_from_coo_table(
+                stk_of_coo, nobs, nvar, context=X.context
+            )
+
+        return X.context.threadpool.submit(_read_X_partitions)
+
+    else:
+        raise TypeError(f"Unexpected NDArray type {type(X)}")
 
 
 def _read_dataframe(
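
To make the new sizing heuristic concrete, here is a sketch of the arithmetic with hypothetical inputs; d0_size and nnz below are invented for illustration, while tgt_point_count is the constant from the diff.

    # Hypothetical matrix: 1M rows, 2B stored values (~2000 per row)
    tgt_point_count = 92 * 1024**2  # ~96.5M points per partition, per the diff
    d0_size = 1_000_000
    nnz = 2_000_000_000

    # Same formula as _read_partitioned_sparse
    partition_sz = max(1024 * round(d0_size * tgt_point_count / nnz / 1024), 1024)
    print(partition_sz)                 # 48128 rows per partition
    print(-(-d0_size // partition_sz))  # 21 partitions, read concurrently

Each partition thus targets roughly tgt_point_count stored values, rounded to a multiple of 1024 rows and floored at 1024 rows; a fully empty matrix falls back to a single partition spanning all of dimension 0.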
Expand Down Expand Up @@ -277,7 +308,7 @@ def to_anndata(
nobs = len(obs_df.index)
nvar = len(var_df.index)

anndata_layers = {}
anndata_layers_futures = {}

# Let them use
# extra_X_layer_names=exp.ms["RNA"].X.keys()
@@ -294,17 +325,15 @@
 
     anndata_X_future: Future[Matrix] | None = None
     if X_layer_name is not None:
-        anndata_X_future = tp.submit(
-            _extract_X_key, measurement, X_layer_name, nobs, nvar
-        )
+        anndata_X_future = _extract_X_key(measurement, X_layer_name, nobs, nvar)
 
     if extra_X_layer_names is not None:
         for extra_X_layer_name in extra_X_layer_names:
             if extra_X_layer_name == X_layer_name:
                 continue
             assert extra_X_layer_name is not None  # appease linter; already checked
             data = _extract_X_key(measurement, extra_X_layer_name, nobs, nvar)
-            anndata_layers[extra_X_layer_name] = data
+            anndata_layers_futures[extra_X_layer_name] = data
 
     if obsm_varm_width_hints is None:
         obsm_varm_width_hints = {}
@@ -341,7 +370,10 @@
     def load_obsp(measurement: Measurement, key: str, nobs: int) -> sp.csr_matrix:
         A = measurement.obsp[key]
         return conversions.csr_from_coo_table(
-            A.read().tables().concat(), nobs, nobs, A.context
+            _read_partitioned_sparse(A, nobs),
+            nobs,
+            nobs,
+            A.context,
         )
 
     for key in measurement.obsp.keys():
@@ -353,7 +385,10 @@ def load_obsp(measurement: Measurement, key: str, nobs: int) -> sp.csr_matrix:
     def load_varp(measurement: Measurement, key: str, nvar: int) -> sp.csr_matrix:
         A = measurement.varp[key]
         return conversions.csr_from_coo_table(
-            A.read().tables().concat(), nvar, nvar, A.context
+            _read_partitioned_sparse(A, nvar),
+            nvar,
+            nvar,
+            A.context,
        )
 
     for key in measurement.varp.keys():
@@ -376,6 +411,7 @@ def load_varp(measurement: Measurement, key: str, nvar: int) -> sp.csr_matrix:
     obsp = _resolve_futures(obsp)
     varp = _resolve_futures(varp)
     anndata_X = anndata_X_future.result() if anndata_X_future is not None else None
+    anndata_layers = _resolve_futures(anndata_layers_futures)
     uns: UnsDict = (
         _resolve_futures(uns_future.result(), deep=True)
         if uns_future is not None
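
With this change every X layer read is submitted to the context threadpool up front and the resulting Futures are resolved together at the end. A minimal sketch of that submit-now/resolve-later shape, using only the standard library (read_layer and layer_names are hypothetical stand-ins, not tiledbsoma API):

    from concurrent.futures import Future, ThreadPoolExecutor

    tp = ThreadPoolExecutor(max_workers=4)

    def read_layer(name: str) -> str:
        # stand-in for a real partitioned array read
        return f"matrix-for-{name}"

    layer_names = ["data", "raw"]
    futures: dict[str, Future[str]] = {n: tp.submit(read_layer, n) for n in layer_names}
    layers = {n: f.result() for n, f in futures.items()}  # resolve all at the end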
@@ -423,7 +459,7 @@ def _extract_obsm_or_varm(
         # 3.8 and we still support Python 3.7
         return matrix
 
-    matrix_tbl = soma_nd_array.read().tables().concat()
+    matrix_tbl = _read_partitioned_sparse(soma_nd_array, num_rows)
 
     # Problem to solve: whereas for other sparse arrays we have:
     #
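
csr_from_coo_table itself is tiledbsoma-internal, but the gist of the conversion it performs (COO coordinates plus values into a CSR matrix) can be sketched with scipy directly; the coordinates and values here are illustrative only:

    import numpy as np
    import scipy.sparse as sp

    # Hypothetical COO triples, as a partitioned sparse read might return
    rows = np.array([0, 0, 2])
    cols = np.array([1, 3, 2])
    vals = np.array([1.0, 2.0, 3.0])

    csr = sp.coo_matrix((vals, (rows, cols)), shape=(3, 4)).tocsr()
    print(csr.indptr)  # [0 2 2 3]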
