From c440fea7d58e6351250bbe0bb8013aa584df3a42 Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Thu, 19 Dec 2024 17:11:03 -0500 Subject: [PATCH] support `somacore>=1.0.24` / `tiledbsoma>=1.15.0rc4` (#19) --- .github/workflows/python-tilledbsoma-ml-compat.yml | 1 + src/tiledbsoma_ml/pytorch.py | 8 +++++++- tests/_utils.py | 1 + 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/.github/workflows/python-tilledbsoma-ml-compat.yml b/.github/workflows/python-tilledbsoma-ml-compat.yml index 6cbde0b..d4ae9b7 100644 --- a/.github/workflows/python-tilledbsoma-ml-compat.yml +++ b/.github/workflows/python-tilledbsoma-ml-compat.yml @@ -31,6 +31,7 @@ jobs: - "tiledbsoma~=1.12.0" - "tiledbsoma~=1.13.0" - "tiledbsoma~=1.14.0" + - "tiledbsoma~=1.15.0rc4" runs-on: ${{ matrix.os }} diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py index ca275f1..09604c8 100644 --- a/src/tiledbsoma_ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -27,7 +27,6 @@ import scipy.sparse as sparse import tiledbsoma as soma import torch -from somacore.query._eager_iter import EagerIterator as _EagerIterator from tiledbsoma_ml._csr import CSR_IO_Buffer from tiledbsoma_ml._distributed import ( @@ -37,6 +36,13 @@ from tiledbsoma_ml._experiment_locator import ExperimentLocator from tiledbsoma_ml._utils import NDArrayNumber, batched, splits +try: + # somacore<1.0.24 / tiledbsoma<1.15 + from somacore.query._eager_iter import EagerIterator as _EagerIterator +except ImportError: + # somacore>=1.0.24 / tiledbsoma>=1.15 + from tiledbsoma._eager_iter import EagerIterator as _EagerIterator + logger = logging.getLogger("tiledbsoma_ml.pytorch") NDArrayJoinId = npt.NDArray[np.int64] diff --git a/tests/_utils.py b/tests/_utils.py index 6f51fdb..0921cbe 100644 --- a/tests/_utils.py +++ b/tests/_utils.py @@ -71,6 +71,7 @@ def add_dataframe(coll: CollectionBase, key: str, value_range: range) -> None: ] ), index_column_names=["soma_joinid"], + domain=((value_range.start, value_range.stop),), ) df.write( pa.Table.from_pydict(