fix: remove calls to Pytorch Dataset len (#8647)
PyTorch Datasets (torch.utils.data.Dataset) aren't guaranteed to implement __len__: a Dataset is either "map-style" or "iterable-style", and while a map-style Dataset must implement __len__, an iterable-style one only optionally does. The __len__ of a PyTorch DataLoader may pass the call through to its Dataset.
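
As a minimal sketch of that distinction (the SquaresMapDataset / SquaresIterableDataset classes are illustrative only, not part of this patch): the map-style Dataset supports len(), while a DataLoader over the iterable-style one raises TypeError when its length is queried.

```python
import torch.utils.data


# Map-style: indexable, and required to implement __len__.
class SquaresMapDataset(torch.utils.data.Dataset):
    def __init__(self, n: int) -> None:
        self.n = n

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, idx: int) -> int:
        return idx * idx


# Iterable-style: only defines __iter__; __len__ is optional and often absent.
class SquaresIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, n: int) -> None:
        self.n = n

    def __iter__(self):
        return iter(i * i for i in range(self.n))


len(torch.utils.data.DataLoader(SquaresMapDataset(8)))       # ok: the DataLoader delegates to the Dataset's __len__
len(torch.utils.data.DataLoader(SquaresIterableDataset(8)))  # raises TypeError: the Dataset has no __len__
```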

A det.pytorch.PyTorchTrial is typically constructed from a det.pytorch.DataLoader, and det.pytorch.DataLoader cannot, itself, front an iterable-style PyTorch Dataset. It is, however, possible to construct a det.pytorch.PyTorchTrial around a torch.utils.data.Dataset that isn't wrapped in a det.pytorch.DataLoader (by returning a bare torch.utils.data.DataLoader) if context.experimental.disable_dataset_reproducibility_checks() is called in the PyTorchTrial's __init__.
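
Roughly, such a trial might look like the sketch below; StreamingTrial and StreamingDataset are hypothetical, train_batch()/evaluate_batch() are omitted, and the point is only the opt-out call in __init__ plus the bare torch.utils.data.DataLoader returns.

```python
import torch.utils.data

from determined import pytorch


class StreamingDataset(torch.utils.data.IterableDataset):
    def __iter__(self):
        return iter(range(1000))  # stand-in for a stream with no known length


class StreamingTrial(pytorch.PyTorchTrial):
    def __init__(self, context: pytorch.PyTorchTrialContext) -> None:
        self.context = context
        # Opt out of Determined's dataset reproducibility checks so that bare
        # torch.utils.data.DataLoader objects may be returned below.
        self.context.experimental.disable_dataset_reproducibility_checks()

    def build_training_data_loader(self) -> torch.utils.data.DataLoader:
        return torch.utils.data.DataLoader(StreamingDataset(), batch_size=16)

    def build_validation_data_loader(self) -> torch.utils.data.DataLoader:
        # This loader has no __len__; before this patch, validation could fail
        # with a TypeError when its length was queried.
        return torch.utils.data.DataLoader(StreamingDataset(), batch_size=16)

    # train_batch() and evaluate_batch() omitted for brevity.
```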

Before this patch, during a PyTorchTrialContext.run we called len() on the trial's validation dataloader. Per the above, it was possible to construct a trial whose validation dataloader did not implement __len__, in which case run would raise a TypeError at runtime.

It turns out, though, that those existing calls to __len__ weren't actually necessary. This patch revises them with no functional change in behavior:

- Instead of calling len(validation_loader) to check for emptiness before iterating through it, count how many times the validation_loader is iterated through and raise the same error if it turned out to be empty (see the sketch after this list).
- Remove a call to len whose result was entirely ignored.
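
The replacement emptiness check, paraphrased from the diff below (validation_loader stands in for the trial's validation dataloader):

```python
idx = -1  # sentinel: stays -1 only if the loader yields no batches
for idx, batch in enumerate(iter(validation_loader)):
    ...  # evaluate the batch as before
if idx == -1:
    raise RuntimeError("validation_loader is empty.")
```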

This PR also makes a couple of "continuous improvement" changes, moving a few pieces of code around and renaming variables so that the logic is a little more obvious.
wes-turner authored Mar 27, 2024
1 parent d9e1088 commit 8f5de35
Showing 1 changed file with 35 additions and 25 deletions.
harness/determined/pytorch/_pytorch_trial.py (35 additions, 25 deletions)
@@ -11,10 +11,11 @@
import warnings
from abc import abstractmethod
from inspect import signature
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Type, Union, cast
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Type, Union

import numpy as np
import torch
import torch.utils.data
from torch import distributed as dist

import determined as det
@@ -104,8 +105,8 @@ def _from_values(
def should_stop(self, step_num: int) -> bool:
if isinstance(self.value, int):
return self._divides(step_num)
if isinstance(self.value, collections.Container):
return step_num in self.value
assert isinstance(self.value, collections.Container)
return step_num in self.value

def _divides(self, steps: int) -> bool:
assert isinstance(steps, int) and isinstance(
@@ -343,7 +344,7 @@ def _aggregate_training_metrics(self, training_metrics: List[Dict]) -> Dict:

assert self.state
if self.context.get_enable_tensorboard_logging():
det.pytorch._log_tb_metrics(
pytorch._log_tb_metrics(
self.context.get_tensorboard_writer(),
"train",
self.state.batches_trained,
@@ -856,7 +857,9 @@ def _should_update_scaler(self) -> bool:
return False
return self.context._should_communicate_and_update()

def _train_batch(self, batch: pytorch.TorchData, epoch_idx: int, batch_idx: int) -> Dict:
def _train_batch(
self, batch: pytorch.TorchData, epoch_idx: int, batch_idx: int
) -> Dict[str, Any]:
# Reset loss IDs for AMP
self.context._loss_ids = {}

@@ -906,11 +909,6 @@ def _train_batch(self, batch: pytorch.TorchData, epoch_idx: int, batch_idx: int)
samples_per_second = self.trial.get_batch_length(batch) / batch_dur
samples_per_second *= self.context.distributed.size

if not isinstance(training_metrics, Dict):
raise TypeError(
f"train_batch() must return a dictionary mapping string names to Tensor metrics, "
f"got {type(training_metrics).__name__}"
)
return training_metrics

@torch.no_grad() # type: ignore
@@ -939,11 +937,10 @@ def _validate(self, searcher_op: Optional[core.SearcherOperation] = None) -> Dic
batch_metrics = []

assert isinstance(self.validation_loader, torch.utils.data.DataLoader)
if len(self.validation_loader) == 0:
raise RuntimeError("validation_loader is empty.")
for callback in self.callbacks.values():
callback.on_validation_epoch_start()

idx = -1 # Later, we'll use this default to see if we've iterated at all.
for idx, batch in enumerate(iter(self.validation_loader)):
if self.context.experimental._auto_to_device:
batch = self.context.to_device(batch)
@@ -974,6 +971,9 @@ def _validate(self, searcher_op: Optional[core.SearcherOperation] = None) -> Dic
if self.test_mode:
break

if idx == -1:
raise RuntimeError("validation_loader is empty.")

for callback in self.callbacks.values():
callback.on_validation_epoch_end(batch_metrics)

@@ -988,14 +988,10 @@ def _validate(self, searcher_op: Optional[core.SearcherOperation] = None) -> Dic

# Gather a list of per-worker (num_inputs, num_batches) tuples.
input_counts = self.context.distributed.gather((num_inputs, idx + 1))
if self.is_chief:
assert input_counts is not None
# Reshape and sum.
num_inputs, num_batches = [sum(n) for n in zip(*input_counts)]

else:
assert self._evaluate_full_dataset_defined(), "evaluate_full_dataset not defined."
self.validation_loader = cast(torch.utils.data.DataLoader, self.validation_loader)
assert self.validation_loader is not None
if self.is_chief:
metrics = self.trial.evaluate_full_dataset(data_loader=self.validation_loader)

@@ -1005,7 +1001,6 @@ def _validate(self, searcher_op: Optional[core.SearcherOperation] = None) -> Dic
)

metrics = pytorch._convert_metrics_to_numpy(metrics)
num_inputs = self.context.get_per_slot_batch_size() * len(self.validation_loader)

metrics.update(
pytorch._convert_metrics_to_numpy(self.context.reduce_metrics(for_training=False))
@@ -1033,12 +1028,17 @@ def _validate(self, searcher_op: Optional[core.SearcherOperation] = None) -> Dic
# common than evaluate_batch() and we can't know how the user processed their
# validation data.
if self._evaluate_batch_defined():
# Reshape and sum.
# TODO: remove the type directive once we upgrade to mypy >= 1.7.0
inputs_total, batches_total = [sum(n) for n in zip(*input_counts)] # type: ignore
step_duration = time.time() - step_start_time
logger.info(
det.util.make_timing_log("validated", step_duration, num_inputs, num_batches)
det.util.make_timing_log(
"validated", step_duration, inputs_total, batches_total
)
)
if self.context.get_enable_tensorboard_logging():
det.pytorch._log_tb_metrics(
pytorch._log_tb_metrics(
self.context.get_tensorboard_writer(),
"val",
self.state.batches_trained,
@@ -1527,20 +1527,30 @@ def train_batch(
pass

@abstractmethod
def build_training_data_loader(self) -> pytorch.DataLoader:
def build_training_data_loader(self) -> Union[pytorch.DataLoader, torch.utils.data.DataLoader]:
"""
Defines the data loader to use during training.
Must return an instance of :py:class:`determined.pytorch.DataLoader`.
Most implementations of :class:`determined.pytorch.PyTorchTrial` will return a
:class:`determined.pytorch.DataLoader` here. Some use cases may not fit the assumptions of
:class:`determined.pytorch.DataLoader`. In that event, a bare
``torch.utils.data.DataLoader`` may be returned if steps in the note atop
:ref:`pytorch-reproducible-dataset` are followed.
"""
pass

@abstractmethod
def build_validation_data_loader(self) -> pytorch.DataLoader:
def build_validation_data_loader(
self,
) -> Union[pytorch.DataLoader, torch.utils.data.DataLoader]:
"""
Defines the data loader to use during validation.
Must return an instance of :py:class:`determined.pytorch.DataLoader`.
Users with a MapDataset will normally return a :class:`determined.pytorch.DataLoader`, but
users with an IterableDataset or with other advanced needs may sacrifice some
Determined-managed functionality (ex: automatic data sharding) to return a bare
:class:`torch.utils.data.DataLoader` following the best-practices described in
:ref:`pytorch-reproducible-dataset`.
"""
pass

@@ -1610,7 +1620,7 @@ def get_batch_length(self, batch: Any) -> int:
"""Count the number of records in a given batch.
Override this method when you are using custom batch types, as produced
when iterating over the :py:class:`determined.pytorch.DataLoader`.
when iterating over the class:`determined.pytorch.DataLoader`.
For example, when using ``pytorch_geometric``:
.. code-block:: python