Skip to content

Commit

Permalink
[python] Add default thread pool in SOMATileDBContext (#2001)
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzi authored Feb 27, 2024
1 parent 2310054 commit 3f64022
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 32 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ repos:
- id: mypy
additional_dependencies:
- "pandas-stubs==1.5.3.230214"
- "somacore==1.0.7"
- "somacore==1.0.8"
- "types-setuptools==67.4.0.3"
args: ["--config-file=apis/python/pyproject.toml", "apis/python/src", "apis/python/devtools"]
pass_filenames: false
2 changes: 1 addition & 1 deletion apis/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def run(self):
"pyarrow>=9.0.0; platform_system!='Darwin'",
"scanpy>=1.9.2",
"scipy",
"somacore==1.0.7",
"somacore==1.0.8",
"tiledb~=0.25.0",
"typing-extensions", # Note "-" even though `import typing_extensions`
],
Expand Down
42 changes: 17 additions & 25 deletions apis/python/src/tiledbsoma/_read_iters.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from __future__ import annotations

import abc
from concurrent import futures
from concurrent.futures import ThreadPoolExecutor
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -88,7 +89,6 @@ def __init__(
size: Optional[Union[int, Sequence[int]]] = None,
reindex_disable_on_axis: Optional[Union[int, Sequence[int]]] = None,
eager: bool = True,
eager_iterator_pool: Optional[ThreadPoolExecutor] = None,
context: Optional[SOMATileDBContext] = None,
):
super().__init__()
Expand All @@ -97,7 +97,14 @@ def __init__(
self.array = array
self.sr = sr
self.eager = eager
self.eager_iterator_pool = eager_iterator_pool

# Assign a thread pool from the context, or create a new one if no context
# is available
if context is not None:
self._threadpool = context.threadpool
else:
self._threadpool = futures.ThreadPoolExecutor()

self.context = context

# raises on various error checks, AND normalizes args
Expand Down Expand Up @@ -261,19 +268,11 @@ class BlockwiseTableReadIter(BlockwiseReadIterBase[BlockwiseTableReadIterResult]

def _create_reader(self) -> Iterator[BlockwiseTableReadIterResult]:
"""Private. Blockwise Arrow Table iterator, restricted to a single axis"""
if self.eager and not self.eager_iterator_pool:
with ThreadPoolExecutor() as _pool:
yield from (
self._reindexed_table_reader(_pool)
if self.axes_to_reindex
else self._table_reader()
)
else:
yield from (
self._reindexed_table_reader(_pool=self.eager_iterator_pool)
if self.axes_to_reindex
else self._table_reader()
)
yield from (
self._reindexed_table_reader(_pool=self._threadpool)
if self.axes_to_reindex
else self._table_reader()
)


class BlockwiseScipyReadIter(BlockwiseReadIterBase[BlockwiseScipyReadIterResult]):
Expand Down Expand Up @@ -331,16 +330,9 @@ def _create_reader(self) -> Iterator[BlockwiseScipyReadIterResult]:
"""
Private. Iterator over SparseNDArray producing sequence of scipy sparse matrix.
"""

if self.eager and not self.eager_iterator_pool:
with ThreadPoolExecutor() as _pool:
yield from self._cs_reader(
_pool
) if self.compress else self._coo_reader(_pool)
else:
yield from self._cs_reader(
_pool=self.eager_iterator_pool
) if self.compress else self._coo_reader(_pool=self.eager_iterator_pool)
yield from self._cs_reader(
_pool=self._threadpool
) if self.compress else self._coo_reader(_pool=self._threadpool)

def _sorted_tbl_reader(
self, _pool: Optional[ThreadPoolExecutor] = None
Expand Down
24 changes: 21 additions & 3 deletions apis/python/src/tiledbsoma/options/_soma_tiledb_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
import functools
import threading
import time
from concurrent import futures
from typing import Any, Dict, Mapping, Optional, Union

import tiledb
from somacore import ContextBase
from typing_extensions import Self

from .. import pytiledbsoma as clib
Expand Down Expand Up @@ -48,7 +50,7 @@ def _maybe_timestamp_ms(input: Optional[OpenTimestamp]) -> Optional[int]:
"""Sentinel object to distinguish default parameters from None."""


class SOMATileDBContext:
class SOMATileDBContext(ContextBase):
"""Maintains TileDB-specific context for TileDB-SOMA objects.
This context can be shared across multiple objects,
including having a child object inherit it from its parent.
Expand All @@ -66,6 +68,7 @@ def __init__(
tiledb_ctx: Optional[tiledb.Ctx] = None,
tiledb_config: Optional[Dict[str, Union[str, float]]] = None,
timestamp: Optional[OpenTimestamp] = None,
threadpool: Optional[futures.ThreadPoolExecutor] = None,
) -> None:
"""Initializes a new SOMATileDBContext.
Expand Down Expand Up @@ -109,6 +112,10 @@ def __init__(
Set to 0xFFFFFFFFFFFFFFFF (UINT64_MAX) to get the absolute
latest revision (i.e., including changes that occur "after"
the current wall time) as of when *each* object is opened.
threadpool: A threadpool to use for concurrent operations. If not
provided, a new ThreadPoolExecutor will be created with
default settings.
"""
if tiledb_ctx is not None and tiledb_config is not None:
raise ValueError(
Expand All @@ -131,8 +138,10 @@ def __init__(
"""The TileDB context to use, either provided or lazily constructed."""
self._timestamp_ms = _maybe_timestamp_ms(timestamp)

"""Lazily construct clib.SOMAContext."""
self.threadpool = threadpool or futures.ThreadPoolExecutor()
"""User specified threadpool. If None, we'll instantiate one ourselves."""
self._native_context: Optional[clib.SOMAContext] = None
"""Lazily construct clib.SOMAContext."""

@property
def timestamp_ms(self) -> Optional[int]:
Expand Down Expand Up @@ -211,6 +220,7 @@ def replace(
tiledb_config: Optional[Dict[str, Any]] = None,
tiledb_ctx: Optional[tiledb.Ctx] = None,
timestamp: Optional[OpenTimestamp] = _SENTINEL, # type: ignore[assignment]
threadpool: Optional[futures.ThreadPoolExecutor] = _SENTINEL, # type: ignore[assignment]
) -> Self:
"""Create a copy of the context, merging changes.
Expand All @@ -226,6 +236,8 @@ def replace(
Explicitly passing ``None`` will remove the timestamp.
For details, see the description of ``timestamp``
in :meth:`__init__`.
threadpool:
A threadpool to replace the current threadpool with.
Lifecycle:
Experimental.
Expand All @@ -250,8 +262,14 @@ def replace(
if timestamp is _SENTINEL:
# Keep the existing timestamp if not overridden.
timestamp = self._timestamp_ms
if threadpool is _SENTINEL:
# Keep the existing threadpool if not overridden.
threadpool = self.threadpool
return type(self)(
tiledb_config=tiledb_config, tiledb_ctx=tiledb_ctx, timestamp=timestamp
tiledb_config=tiledb_config,
tiledb_ctx=tiledb_ctx,
timestamp=timestamp,
threadpool=threadpool,
)

def _open_timestamp_ms(self, in_timestamp: Optional[OpenTimestamp]) -> int:
Expand Down
37 changes: 35 additions & 2 deletions apis/python/tests/test_experiment_query.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from concurrent import futures
from typing import Tuple
from unittest import mock

import numpy as np
import pandas as pd
Expand All @@ -8,7 +10,7 @@
from somacore import options

import tiledbsoma as soma
from tiledbsoma import _factory
from tiledbsoma import SOMATileDBContext, _factory
from tiledbsoma._collection import CollectionBase
from tiledbsoma.experiment_query import X_as_series

Expand Down Expand Up @@ -84,6 +86,11 @@ def soma_experiment(
return _factory.open((tmp_path / "exp").as_posix())


def get_soma_experiment_with_context(soma_experiment, context):
soma_experiment.close()
return _factory.open(soma_experiment.uri, context=context)


@pytest.mark.parametrize("n_obs,n_vars,X_layer_names", [(101, 11, ("raw", "extra"))])
def test_experiment_query_all(soma_experiment):
"""Test a query with default obs_query / var_query -- i.e., query all."""
Expand Down Expand Up @@ -517,10 +524,16 @@ def test_error_corners(soma_experiment: soma.Experiment):
def test_query_cleanup(soma_experiment: soma.Experiment):
"""
Verify soma.Experiment.query works as context manager and stand-alone,
and that it cleans up correct.
and that it cleans up correctly.
"""
from contextlib import closing

# Forces a context without a thread pool, which in turn causes ExperimentAxisQuery
# to own (and release) its own thread pool.
context = SOMATileDBContext()
context.threadpool = None
soma_experiment = get_soma_experiment_with_context(soma_experiment, context)

with soma_experiment.axis_query("RNA") as query:
assert query.n_obs == 1001
assert query.n_vars == 99
Expand Down Expand Up @@ -873,3 +886,23 @@ def add_sparse_array(coll: CollectionBase, key: str, shape: Tuple[int, int]) ->
)
)
a.write(tensor)


@pytest.mark.parametrize("n_obs,n_vars", [(1001, 99)])
def test_experiment_query_uses_threadpool_from_context(soma_experiment):
"""
Verify that ExperimentAxisQuery uses the threadpool from the context
"""

pool = mock.Mock(wraps=futures.ThreadPoolExecutor(max_workers=4))
pool.submit.assert_not_called()

context = SOMATileDBContext(threadpool=pool)
soma_experiment = get_soma_experiment_with_context(soma_experiment, context=context)

with soma_experiment.axis_query("RNA") as query:
# to_anndata uses the threadpool
adata = query.to_anndata(X_name="raw")
assert adata is not None

pool.submit.assert_called()
52 changes: 52 additions & 0 deletions apis/python/tests/test_sparse_nd_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import operator
import pathlib
import sys
from concurrent import futures
from typing import Any, Dict, List, Tuple, Union
from unittest import mock

import numpy as np
import pyarrow as pa
Expand Down Expand Up @@ -1719,3 +1721,53 @@ def test_blockwise_scipy_reindex_disable_major_dim(
.scipy(compress=False)
)
assert isinstance(sp, sparse.coo_matrix)


@pytest.mark.parametrize("density,shape", [(0.1, (100, 100))])
def test_blockwise_iterator_uses_thread_pool_from_context(
a_random_sparse_nd_array: str, shape: Tuple[int, ...]
) -> None:
pool = mock.Mock(wraps=futures.ThreadPoolExecutor(max_workers=2))
pool.submit.assert_not_called()

context = SOMATileDBContext(threadpool=pool)
with soma.open(a_random_sparse_nd_array, mode="r", context=context) as A:
axis = 0
size = 50
tbls = (
A.read()
.blockwise(
axis=axis,
size=size,
)
.tables()
)

# The iteration needs to happen to ensure the threadpool is used
for tbl in tbls:
assert tbl is not None

pool.submit.assert_called()

pool.reset_mock()
pool.submit.assert_not_called()

with soma.open(a_random_sparse_nd_array, mode="r", context=context) as A:
axis = 0
size = 50
arrs = (
A.read()
.blockwise(
axis=axis,
size=size,
)
.scipy()
)

# The iteration needs to happen to ensure the threadpool is used
for arr in arrs:
assert arr is not None

pool.submit.assert_called()

pool.shutdown()

0 comments on commit 3f64022

Please sign in to comment.