From 16a6090e9575e690b624d4e8394da77c84381fe3 Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Fri, 6 Dec 2024 11:24:25 -0500 Subject: [PATCH] pytorch.py nits --- src/tiledbsoma_ml/pytorch.py | 12 ++++++------ tests/test_pytorch.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py index a5561fb..d783ba7 100644 --- a/src/tiledbsoma_ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -175,7 +175,7 @@ def __init__( When using this class in any distributed mode, calling the :meth:`set_epoch` method at the beginning of each epoch **before** creating the :class:`DataLoader` iterator is necessary to make shuffling work properly across multiple epochs. Otherwise, - the same ordering will be always used. + the same ordering will always be used. In addition, when using shuffling in a distributed configuration (e.g., ``DDP``), you must provide a seed, ensuring that the same shuffle is used across all replicas. @@ -251,19 +251,19 @@ def _create_obs_joinids_partition(self) -> Iterator[NDArrayJoinId]: if self.shuffle: assert self.io_batch_size % self.shuffle_chunk_size == 0 shuffle_split = np.array_split( - _gpu_split, max(1, ceil(len(_gpu_split) / self.shuffle_chunk_size)) + _gpu_split, max(1, ceil(min_len / self.shuffle_chunk_size)) ) # Deterministically create RNG - state must be same across all processes, ensuring # that the joinid partitions are identical across all processes. rng = np.random.default_rng(self.seed + self.epoch + 99) rng.shuffle(shuffle_split) - obs_joinids_chunked = list( + obs_joinids_chunked = [ np.concatenate(b) for b in _batched( shuffle_split, self.io_batch_size // self.shuffle_chunk_size ) - ) + ] else: obs_joinids_chunked = np.array_split( _gpu_split, max(1, ceil(len(_gpu_split) / self.io_batch_size)) @@ -463,7 +463,7 @@ def _io_batch_iter( f"Retrieving next SOMA IO batch of length {len(obs_coords)}..." ) - # to maximize optty's for concurrency, when in eager_fetch mode, + # To maximize opportunities for concurrency, when in eager_fetch mode, # create the X read iterator first, as the eager iterator will begin # the read-ahead immediately. Then proceed to fetch obs DataFrame. # This matters most on latent backing stores, e.g., S3. @@ -910,7 +910,7 @@ def epoch(self) -> int: def experiment_dataloader( - ds: torchdata.datapipes.iter.IterDataPipe | torch.utils.data.IterableDataset, + ds: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, **dataloader_kwargs: Any, ) -> torch.utils.data.DataLoader: """Factory method for :class:`torch.utils.data.DataLoader`. This method can be used to safely instantiate a diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index 05bf6ca..52f7aee 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -776,7 +776,7 @@ def test_experiment_axis_query_iterable_error_checks( dp[0] with pytest.raises(ValueError): - dp = ExperimentAxisQueryIterable( + ExperimentAxisQueryIterable( query, obs_column_names=(), X_name="raw",