From 37b8a80b30b298a00f7358015a55de8fce492626 Mon Sep 17 00:00:00 2001
From: bkmartinjr
Date: Mon, 23 Sep 2024 13:41:19 -0700
Subject: [PATCH] add experiment_dataloader helper API

---
 src/tiledbsoma_ml/__init__.py |   2 +
 src/tiledbsoma_ml/pytorch.py  |  87 +++++++++++++++
 tests/test_pytorch.py         | 196 +++++++++++++++++++++++++++++++++-
 3 files changed, 284 insertions(+), 1 deletion(-)

diff --git a/src/tiledbsoma_ml/__init__.py b/src/tiledbsoma_ml/__init__.py
index 49e850e..263608f 100644
--- a/src/tiledbsoma_ml/__init__.py
+++ b/src/tiledbsoma_ml/__init__.py
@@ -8,6 +8,7 @@
 from .pytorch import (
     ExperimentAxisQueryIterableDataset,
     ExperimentAxisQueryIterDataPipe,
+    experiment_dataloader,
 )
 
 __version__ = "0.1.0-dev"
@@ -15,4 +16,5 @@
 __all__ = [
     "ExperimentAxisQueryIterDataPipe",
     "ExperimentAxisQueryIterableDataset",
+    "experiment_dataloader",
 ]
diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py
index 32d90b3..8e6c077 100644
--- a/src/tiledbsoma_ml/pytorch.py
+++ b/src/tiledbsoma_ml/pytorch.py
@@ -722,6 +722,73 @@ def shape(self) -> Tuple[int, int]:
         return self._exp_iter.shape
 
 
+def experiment_dataloader(
+    ds: torchdata.datapipes.iter.IterDataPipe | torch.utils.data.IterableDataset,
+    **dataloader_kwargs: Any,
+) -> torch.utils.data.DataLoader:
+    """Factory function for :class:`torch.utils.data.DataLoader`. It can be used to safely instantiate a
+    :class:`torch.utils.data.DataLoader` that works with :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset`
+    or :class:`tiledbsoma_ml.ExperimentAxisQueryIterDataPipe`.
+
+    Several :class:`torch.utils.data.DataLoader` constructor parameters are not applicable, or are non-performant,
+    when using loaders from this module, including ``shuffle``, ``batch_size``, ``sampler``, and ``batch_sampler``.
+    Specifying any of these parameters will result in an error.
+
+    Refer to ``https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader`` for more information on
+    :class:`torch.utils.data.DataLoader` parameters.
+
+    Args:
+        ds:
+            A :class:`torch.utils.data.IterableDataset` or a :class:`torchdata.datapipes.iter.IterDataPipe`. May
+            include chained data pipes.
+        **dataloader_kwargs:
+            Additional keyword arguments to pass to the :class:`torch.utils.data.DataLoader` constructor,
+            except for ``shuffle``, ``batch_size``, ``sampler``, and ``batch_sampler``, which are not
+            supported when using data loaders in this module. A no-op ``collate_fn`` is used unless one is given.
+
+    Returns:
+        A :class:`torch.utils.data.DataLoader`.
+
+    Raises:
+        ValueError: if any of the ``shuffle``, ``batch_size``, ``sampler``, or ``batch_sampler`` params
+            are passed as keyword arguments.
+
+    Lifecycle:
+        experimental
+    """
+    unsupported_dataloader_args = [
+        "shuffle",
+        "batch_size",
+        "sampler",
+        "batch_sampler",
+    ]
+    if set(unsupported_dataloader_args).intersection(dataloader_kwargs.keys()):
+        raise ValueError(
+            f"The {','.join(unsupported_dataloader_args)} DataLoader parameters are not supported"
+        )
+
+    if dataloader_kwargs.get("num_workers", 0) > 0:
+        _init_multiprocessing()  # worker processes requested: ensure the "spawn" start method
+
+    if "collate_fn" not in dataloader_kwargs:  # a caller-supplied collate_fn takes precedence
+        dataloader_kwargs["collate_fn"] = _collate_noop
+
+    return torch.utils.data.DataLoader(
+        ds,
+        batch_size=None,  # batching is handled by upstream iterator
+        shuffle=False,  # shuffling is handled by upstream iterator
+        **dataloader_kwargs,
+    )
+
+
+def _collate_noop(datum: _T) -> _T:
+    """Return ``datum`` unchanged; the default ``collate_fn`` installed by ``experiment_dataloader``.
+
+    Private.
+    """
+    return datum
+
+
 def _splits(total_length: int, sections: int) -> npt.NDArray[np.intp]:
     """For `total_length` points, compute start/stop offsets that split the length into roughly equal sizes.
 
@@ -795,3 +862,23 @@ def _get_worker_world_rank() -> Tuple[int, int]:
             num_workers = worker_info.num_workers
             worker = worker_info.id
     return num_workers, worker
+
+
+def _init_multiprocessing() -> None:
+    """Ensures use of "spawn" for starting child processes with multiprocessing.
+
+    Forked processes are known to be problematic:
+      https://pytorch.org/docs/stable/notes/multiprocessing.html#avoiding-and-fighting-deadlocks
+    Also, CUDA does not support forked child processes:
+      https://pytorch.org/docs/stable/notes/multiprocessing.html#cuda-in-multiprocessing
+
+    A warning is logged only when an already-selected, non-"spawn" start method is replaced. Private.
+    """
+    # allow_none=True reports "not yet selected" as None instead of fixing the platform default
+    orig_start_method = torch.multiprocessing.get_start_method(allow_none=True)
+    if orig_start_method != "spawn":
+        if orig_start_method:
+            logger.warning(
+                f'switching torch multiprocessing start method from "{orig_start_method}" to "spawn"'
+            )
+        torch.multiprocessing.set_start_method("spawn", force=True)
diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py
index 31703a2..c9679ff 100644
--- a/tests/test_pytorch.py
+++ b/tests/test_pytorch.py
@@ -6,10 +6,12 @@
 from __future__ import annotations
 
 import pathlib
-from typing import Callable, List, Optional, Sequence, Union
+from functools import partial
+from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
 from unittest.mock import patch
 
 import numpy as np
+import numpy.typing as npt
 import pandas as pd
 import pyarrow as pa
 import pytest
@@ -29,6 +31,7 @@
         ExperimentAxisQueryIterable,
         ExperimentAxisQueryIterableDataset,
         ExperimentAxisQueryIterDataPipe,
+        experiment_dataloader,
     )
 except ImportError:
     # this should only occur when not running `ml`-marked tests
@@ -466,6 +469,37 @@ def test_batching__partial_soma_batches_are_concatenated(
         assert [len(batch[0]) for batch in full_result] == [3, 3, 3, 1]
 
 
+@pytest.mark.parametrize(
+    "obs_range,var_range,X_value_gen", [(6, 3, pytorch_x_value_gen)]
+)
+@pytest.mark.parametrize(
+    "PipeClass",
+    (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset),
+)
+def test_multiprocessing__returns_full_result(
+    PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset,
+    soma_experiment: Experiment,
+) -> None:
+    """Tests the ExperimentAxisQueryIterDataPipe provides all data, as collected from multiple processes that are managed by a
+    PyTorch DataLoader with multiple workers configured."""
+    with soma_experiment.axis_query(measurement_name="RNA") as query:
+        dp = PipeClass(
+            query,
+            X_name="raw",
+            obs_column_names=["soma_joinid", "label"],
+            io_batch_size=3,  # two chunks, one per worker
+        )
+        # Note we're testing the ExperimentAxisQueryIterDataPipe via a DataLoader, since this is what sets up the multiprocessing
+        dl = experiment_dataloader(dp, num_workers=2)
+
+        full_result = list(iter(dl))
+
+        soma_joinids = np.concatenate(
+            [t[1]["soma_joinid"].to_numpy() for t in full_result]
+        )
+        assert sorted(soma_joinids) == list(range(6))
+
+
 @pytest.mark.parametrize(
     "obs_range,var_range,X_value_gen",
     [(6, 3, pytorch_x_value_gen), (7, 3, pytorch_x_value_gen)],
@@ -571,6 +605,130 @@ def test_distributed_and_multiprocessing__returns_data_partition_for_rank(
             assert sorted(soma_joinids) == expected_joinids.tolist()
 
 
+@pytest.mark.parametrize(
+    "obs_range,var_range,X_value_gen,use_eager_fetch",
+    [(3, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)],
+)
+@pytest.mark.parametrize(
+    "PipeClass",
+    (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset),
+)
+def test_experiment_dataloader__non_batched(
+    PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset,
+    soma_experiment: Experiment,
+    use_eager_fetch: bool,
+) -> None:
+    with soma_experiment.axis_query(measurement_name="RNA") as query:
+        dp = PipeClass(
+            query,
+            X_name="raw",
+            obs_column_names=["label"],
+            use_eager_fetch=use_eager_fetch,
+        )
+        dl = experiment_dataloader(dp)
+        data = [row for row in dl]
+        assert all(d[0].shape == (3,) for d in data)
+        assert all(d[1].shape == (1, 1) for d in data)
+
+        row = data[0]
+        assert row[0].tolist() == [0, 1, 0]
+        assert row[1]["label"].tolist() == ["0"]
+
+
+@pytest.mark.parametrize(
+    "obs_range,var_range,X_value_gen,use_eager_fetch",
+    [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)],
+)
+@pytest.mark.parametrize(
+    "PipeClass",
+    (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset),
+)
+def test_experiment_dataloader__batched(
+    PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset,
+    soma_experiment: Experiment,
+    use_eager_fetch: bool,
+) -> None:
+    with soma_experiment.axis_query(measurement_name="RNA") as query:
+        dp = PipeClass(
+            query,
+            X_name="raw",
+            batch_size=3,
+            use_eager_fetch=use_eager_fetch,
+        )
+        dl = experiment_dataloader(dp)
+        data = [row for row in dl]
+
+        batch = data[0]
+        assert batch[0].tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
+        assert batch[1].to_numpy().tolist() == [[0], [1], [2]]
+
+
+@pytest.mark.parametrize(
+    "obs_range,var_range,X_value_gen,use_eager_fetch",
+    [
+        (10, 3, pytorch_x_value_gen, use_eager_fetch)
+        for use_eager_fetch in (True, False)
+    ],
+)
+@pytest.mark.parametrize(
+    "PipeClass",
+    (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset),
+)
+def test_experiment_dataloader__batched_length(
+    PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset,
+    soma_experiment: Experiment,
+    use_eager_fetch: bool,
+) -> None:
+    with soma_experiment.axis_query(measurement_name="RNA") as query:
+        dp = PipeClass(
+            query,
+            X_name="raw",
+            obs_column_names=["label"],
+            batch_size=3,
+            use_eager_fetch=use_eager_fetch,
+        )
+        dl = experiment_dataloader(dp)
+        assert len(dl) == len(list(dl))
+
+
+@pytest.mark.parametrize(
+    "obs_range,var_range,X_value_gen,batch_size",
+    [(10, 3, pytorch_x_value_gen, batch_size) for batch_size in (1, 3, 10)],
+)
+@pytest.mark.parametrize(
+    "PipeClass",
+    (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset),
+)
+def test_experiment_dataloader__collate_fn(
+    PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset,
+    soma_experiment: Experiment,
+    batch_size: int,
+) -> None:
+    def collate_fn(
+        batch_size: int, data: Tuple[npt.NDArray[np.number[Any]], pd.DataFrame]
+    ) -> Tuple[npt.NDArray[np.number[Any]], pd.DataFrame]:
+        assert isinstance(data, tuple)
+        assert len(data) == 2
+        assert isinstance(data[0], np.ndarray) and isinstance(data[1], pd.DataFrame)
+        if batch_size > 1:
+            assert data[0].shape[0] == data[1].shape[0]
+            assert data[0].shape[0] <= batch_size
+        else:
+            assert data[0].ndim == 1
+        assert data[1].shape[1] <= batch_size
+        return data
+
+    with soma_experiment.axis_query(measurement_name="RNA") as query:
+        dp = PipeClass(
+            query,
+            X_name="raw",
+            obs_column_names=["label"],
+            batch_size=batch_size,
+        )
+        dl = experiment_dataloader(dp, collate_fn=partial(collate_fn, batch_size))
+        assert len(list(dl)) > 0
+
+
 @pytest.mark.parametrize(
     "obs_range,var_range,X_value_gen,use_eager_fetch",
     [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)],
@@ -592,6 +750,42 @@ def test__X_tensor_dtype_matches_X_matrix(
         assert data[0].dtype == np.float32
 
 
+@pytest.mark.parametrize(
+    "obs_range,var_range,X_value_gen", [(10, 1, pytorch_x_value_gen)]
+)
+def test__pytorch_splitting(
+    soma_experiment: Experiment,
+) -> None:
+    with soma_experiment.axis_query(measurement_name="RNA") as query:
+        dp = ExperimentAxisQueryIterDataPipe(
+            query,
+            X_name="raw",
+            obs_column_names=["label"],
+        )
+        # function not available for IterableDataset, yet....
+        dp_train, dp_test = dp.random_split(
+            weights={"train": 0.7, "test": 0.3}, seed=1234
+        )
+        dl = experiment_dataloader(dp_train)
+
+        all_rows = list(iter(dl))
+        assert len(all_rows) == 7
+
+
+def test_experiment_dataloader__unsupported_params__fails() -> None:
+    with patch(
+        "tiledbsoma_ml.pytorch.ExperimentAxisQueryIterDataPipe"
+    ) as dummy_exp_data_pipe:
+        with pytest.raises(ValueError):
+            experiment_dataloader(dummy_exp_data_pipe, shuffle=True)
+        with pytest.raises(ValueError):
+            experiment_dataloader(dummy_exp_data_pipe, batch_size=3)
+        with pytest.raises(ValueError):
+            experiment_dataloader(dummy_exp_data_pipe, batch_sampler=[])
+        with pytest.raises(ValueError):
+            experiment_dataloader(dummy_exp_data_pipe, sampler=[])
+
+
 def test_batched() -> None:
     from tiledbsoma_ml.pytorch import _batched