Skip to content

Commit

Permalink
pytorch.py nits
Browse files: browse the repository at this point in the history
  • Loading branch information
ryan-williams committed Dec 17, 2024
1 parent f039245 commit 16a6090
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
12 changes: 6 additions & 6 deletions src/tiledbsoma_ml/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def __init__(
When using this class in any distributed mode, calling the :meth:`set_epoch` method at
the beginning of each epoch **before** creating the :class:`DataLoader` iterator
is necessary to make shuffling work properly across multiple epochs. Otherwise,
the same ordering will be always used.
the same ordering will always be used.
In addition, when using shuffling in a distributed configuration (e.g., ``DDP``), you
must provide a seed, ensuring that the same shuffle is used across all replicas.
Expand Down Expand Up @@ -251,19 +251,19 @@ def _create_obs_joinids_partition(self) -> Iterator[NDArrayJoinId]:
if self.shuffle:
assert self.io_batch_size % self.shuffle_chunk_size == 0
shuffle_split = np.array_split(
_gpu_split, max(1, ceil(len(_gpu_split) / self.shuffle_chunk_size))
_gpu_split, max(1, ceil(min_len / self.shuffle_chunk_size))
)

# Deterministically create RNG - state must be same across all processes, ensuring
# that the joinid partitions are identical across all processes.
rng = np.random.default_rng(self.seed + self.epoch + 99)
rng.shuffle(shuffle_split)
obs_joinids_chunked = list(
obs_joinids_chunked = [
np.concatenate(b)
for b in _batched(
shuffle_split, self.io_batch_size // self.shuffle_chunk_size
)
)
]
else:
obs_joinids_chunked = np.array_split(
_gpu_split, max(1, ceil(len(_gpu_split) / self.io_batch_size))
Expand Down Expand Up @@ -463,7 +463,7 @@ def _io_batch_iter(
f"Retrieving next SOMA IO batch of length {len(obs_coords)}..."
)

# to maximize optty's for concurrency, when in eager_fetch mode,
# To maximize opportunities for concurrency, when in eager_fetch mode,
# create the X read iterator first, as the eager iterator will begin
# the read-ahead immediately. Then proceed to fetch obs DataFrame.
# This matters most on latent backing stores, e.g., S3.
Expand Down Expand Up @@ -910,7 +910,7 @@ def epoch(self) -> int:


def experiment_dataloader(
ds: torchdata.datapipes.iter.IterDataPipe | torch.utils.data.IterableDataset,
ds: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset,
**dataloader_kwargs: Any,
) -> torch.utils.data.DataLoader:
"""Factory method for :class:`torch.utils.data.DataLoader`. This method can be used to safely instantiate a
Expand Down
2 changes: 1 addition & 1 deletion tests/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,7 @@ def test_experiment_axis_query_iterable_error_checks(
dp[0]

with pytest.raises(ValueError):
dp = ExperimentAxisQueryIterable(
ExperimentAxisQueryIterable(
query,
obs_column_names=(),
X_name="raw",
Expand Down

0 comments on commit 16a6090

Please sign in to comment.