Delete TrainingEpochLoop._dataloader_idx which always equals 0 (#8911)
carmocca authored Aug 16, 2021
1 parent 32c7cce commit d0efb55
Showing 8 changed files with 25 additions and 36 deletions.
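The user-facing effect of the change: the batch-transfer hooks (`on_before_batch_transfer`, `transfer_batch_to_device`, `on_after_batch_transfer`) now see a default `dataloader_idx` of `0` instead of `None`, matching the value the training loop always passed anyway. A minimal sketch, not part of this commit (the toy model, dataset, and `fast_dev_run` setting are illustrative assumptions):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class TransferAwareModel(pl.LightningModule):
    """Toy module whose batch-transfer hook relies on the integer default."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def on_before_batch_transfer(self, batch, dataloader_idx):
        # dataloader_idx arrives as 0 for the single training dataloader,
        # so integer comparisons no longer need a None guard.
        assert dataloader_idx == 0
        return batch

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=4)


if __name__ == "__main__":
    # fast_dev_run executes a single training batch, enough to hit the hook once
    pl.Trainer(fast_dev_run=True).fit(TransferAwareModel())
```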
6 changes: 2 additions & 4 deletions pytorch_lightning/accelerators/accelerator.py
@@ -155,9 +155,7 @@ def teardown(self) -> None:
         """
         self.training_type_plugin.teardown()

-    def batch_to_device(
-        self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: Optional[int] = None
-    ) -> Any:
+    def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any:
         """Moves the batch to the correct device.
         The returned batch is of the same type as the input batch, just having all tensors on the correct device.
@@ -171,7 +169,7 @@ def batch_to_device(

         if model is not None and not isinstance(self.training_type_plugin, DataParallelPlugin):
             # no need to transfer batch to device in DP mode
-            return model._apply_batch_transfer_handler(batch, device, dataloader_idx)
+            return model._apply_batch_transfer_handler(batch, device=device, dataloader_idx=dataloader_idx)

         return move_data_to_device(batch, device)
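For reference, a hedged sketch of the fallback path at the end of this hunk: when no LightningModule is attached, or under `DataParallelPlugin`, `batch_to_device` delegates to `move_data_to_device`, which walks nested containers. The CPU target and the toy batch below are illustrative assumptions:

```python
import torch
from pytorch_lightning.utilities.apply_func import move_data_to_device

# A nested batch such as the accelerator might receive from a dataloader
batch = {"x": torch.randn(2, 4), "meta": {"y": torch.randint(0, 2, (2,))}}

# Equivalent of the `return move_data_to_device(batch, device)` fallback;
# "cpu" stands in for whatever device the accelerator would choose.
moved = move_data_to_device(batch, torch.device("cpu"))
assert moved["x"].device.type == "cpu"
assert moved["meta"]["y"].device.type == "cpu"
```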
2 changes: 1 addition & 1 deletion pytorch_lightning/core/lightning.py
@@ -272,7 +272,7 @@ def logger(self):
         return self.trainer.logger if self.trainer else None

     def _apply_batch_transfer_handler(
-        self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: Optional[int] = None
+        self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0
     ) -> Any:
         device = device or self.device
         batch = self.on_before_batch_transfer(batch, dataloader_idx)
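To make the hook chain concrete, a simplified stand-in (not the real `LightningModule`; the class and its no-op hooks are illustrative) showing how `_apply_batch_transfer_handler` threads the new integer default through the three batch-transfer hooks:

```python
from typing import Any, Optional

import torch


class MiniTransferHandler:
    """Simplified mock of the hook chain; the real class has many more members."""

    def __init__(self, device: torch.device):
        self.device = device

    # Default hook implementations mirror the no-op behaviour of LightningModule.
    def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
        return batch

    def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any:
        return batch.to(device) if isinstance(batch, torch.Tensor) else batch

    def on_after_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
        return batch

    def _apply_batch_transfer_handler(
        self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0
    ) -> Any:
        device = device or self.device
        batch = self.on_before_batch_transfer(batch, dataloader_idx)
        batch = self.transfer_batch_to_device(batch, device, dataloader_idx)
        batch = self.on_after_batch_transfer(batch, dataloader_idx)
        return batch


handler = MiniTransferHandler(torch.device("cpu"))
moved = handler._apply_batch_transfer_handler(torch.randn(2, 4))  # dataloader_idx defaults to 0
assert moved.device.type == "cpu"
```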
17 changes: 7 additions & 10 deletions pytorch_lightning/loops/batch/training_batch_loop.py
@@ -73,13 +73,12 @@ def optimizer_freq_cumsum(self) -> int:
     def connect(self, **kwargs: "Loop") -> None:
         raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")

-    def run(self, batch: Any, batch_idx: int, dataloader_idx: int) -> AttributeDict:
+    def run(self, batch: Any, batch_idx: int) -> AttributeDict:
         """Runs all the data splits and the ``on_batch_start`` and ``on_train_batch_start`` hooks
         Args:
             batch: the current batch to run the train step on
             batch_idx: the index of the current batch
-            dataloader_idx: the index of the dataloader producing the current batch
         """
         if batch is None:
             self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
@@ -92,13 +91,13 @@ def run(self, batch: Any, batch_idx: int, dataloader_idx: int) -> AttributeDict:
             return AttributeDict(signal=-1)

         # hook
-        response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, dataloader_idx)
+        response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, 0)
         if response == -1:
             return AttributeDict(signal=-1)

         self.trainer.fit_loop.epoch_loop.batch_progress.increment_started()

-        super().run(batch, batch_idx, dataloader_idx)
+        super().run(batch, batch_idx)
         output = AttributeDict(signal=0, training_step_output=self.batch_outputs)
         self.batch_outputs = None  # free memory
         return output
@@ -108,26 +107,24 @@ def reset(self) -> None:
         self._hiddens = None
         self.batch_outputs = [[] for _ in range(len(self.trainer.optimizers))]

-    def on_run_start(self, batch: Any, batch_idx: int, dataloader_idx: int):
+    def on_run_start(self, batch: Any, batch_idx: int):
         """Splits the data into tbptt splits
         Args:
             batch: the current batch to run the trainstep on
             batch_idx: the index of the current batch
-            dataloader_idx: the index of the dataloader producing the current batch
         """
-        void(batch_idx, dataloader_idx)
+        void(batch_idx)
         self._remaining_splits = list(enumerate(self._tbptt_split_batch(batch)))

-    def advance(self, batch, batch_idx, dataloader_idx):
+    def advance(self, batch, batch_idx):
         """Runs the train step together with optimization (if necessary) on the current batch split
         Args:
             batch: the current batch to run the training on (this is not the split!)
             batch_idx: the index of the current batch
-            dataloader_idx: the index of the dataloader producing the current batch
         """
-        void(batch, dataloader_idx)
+        void(batch)
         split_idx, split_batch = self._remaining_splits.pop(0)
         self.split_idx = split_idx
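Why `run`, `on_run_start`, and `advance` have to drop `dataloader_idx` together: the base loop forwards the same positional arguments to each hook, so their signatures must stay in sync. A toy stand-in (not the actual `pytorch_lightning.loops.base.Loop`) illustrating that contract:

```python
from typing import Any


class MiniLoop:
    """Toy base loop: run() passes the same *args to every hook, so subclasses
    must keep on_run_start/advance signatures in sync with run()."""

    def __init__(self) -> None:
        self.done = False

    def run(self, *args: Any) -> Any:
        self.reset()
        self.on_run_start(*args)
        while not self.done:
            self.advance(*args)
        return self.on_run_end()

    def reset(self) -> None: ...
    def on_run_start(self, *args: Any) -> None: ...
    def advance(self, *args: Any) -> None: ...
    def on_run_end(self) -> Any: ...


class MiniBatchLoop(MiniLoop):
    """After this commit the batch loop hooks take only (batch, batch_idx)."""

    def reset(self) -> None:
        self.done = False
        self.outputs = []

    def on_run_start(self, batch: Any, batch_idx: int) -> None:
        self.splits = list(batch)  # stand-in for the tbptt splits

    def advance(self, batch: Any, batch_idx: int) -> None:
        self.outputs.append((batch_idx, self.splits.pop(0)))
        self.done = not self.splits

    def on_run_end(self) -> Any:
        return self.outputs


print(MiniBatchLoop().run(["split_a", "split_b"], 0))
# [(0, 'split_a'), (0, 'split_b')]
```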
12 changes: 3 additions & 9 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -22,7 +22,6 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.types import STEP_OUTPUT
-from pytorch_lightning.utilities.warnings import WarningCache


 class TrainingEpochLoop(loops.Loop):
@@ -48,8 +47,6 @@ def __init__(self, min_steps: int, max_steps: int):
         self.val_loop: Optional["loops.EvaluationLoop"] = None

         self._results = ResultCollection(training=True)
-        self._dataloader_idx: Optional[int] = None
-        self._warning_cache: WarningCache = WarningCache()
         self._epoch_output: Optional[List[List[STEP_OUTPUT]]] = None

     @property
@@ -87,7 +84,6 @@ def connect(
     def reset(self) -> None:
         """Resets the internal state of the loop for a new run"""
         self.is_last_batch = False
-        self._dataloader_idx = 0

         # track epoch output
         self._epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers(self.total_batch_idx))]
@@ -120,12 +116,12 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
         # TRAINING_STEP + TRAINING_STEP_END
         # ------------------------------------
         with self.trainer.profiler.profile("training_batch_to_device"):
-            batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=self._dataloader_idx)
+            batch = self.trainer.accelerator.batch_to_device(batch)

         self.batch_progress.increment_ready()

         with self.trainer.profiler.profile("run_training_batch"):
-            batch_output = self.batch_loop.run(batch, self.batch_idx, self._dataloader_idx)
+            batch_output = self.batch_loop.run(batch, self.batch_idx)

         self.batch_progress.increment_processed()
@@ -143,9 +139,7 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
         processed_batch_end_outputs = self._prepare_outputs(batch_end_outputs, batch_mode=True)

         # hook
-        self.trainer.call_hook(
-            "on_train_batch_end", processed_batch_end_outputs, batch, self.batch_idx, self._dataloader_idx
-        )
+        self.trainer.call_hook("on_train_batch_end", processed_batch_end_outputs, batch, self.batch_idx, 0)
         self.trainer.call_hook("on_batch_end")
         self.trainer.logger_connector.on_batch_end()
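Note that `on_train_batch_end` still carries a `dataloader_idx` argument at this point; after this commit it is the hard-coded literal `0` rather than a tracked attribute. A hedged callback sketch, assuming the hook signature of this release line:

```python
from pytorch_lightning.callbacks import Callback


class BatchEndProbe(Callback):
    """Asserts that the training dataloader index is always reported as 0."""

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        # With a single training dataloader, this value never changes.
        assert dataloader_idx == 0


# Usage (model is any LightningModule):
#   Trainer(callbacks=[BatchEndProbe()], fast_dev_run=True).fit(model)
```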
6 changes: 3 additions & 3 deletions tests/core/test_datamodules.py
@@ -392,22 +392,22 @@ class CurrentTestDM(LightningDataModule):
         on_after_batch_transfer_hook_rank = None

         def on_before_batch_transfer(self, batch, dataloader_idx):
-            assert dataloader_idx is None
+            assert dataloader_idx == 0
             self.on_before_batch_transfer_hook_rank = self.rank
             self.rank += 1
             batch.samples += 1
             return batch

         def on_after_batch_transfer(self, batch, dataloader_idx):
-            assert dataloader_idx is None
+            assert dataloader_idx == 0
             assert batch.samples.device == batch.targets.device == expected_device
             self.on_after_batch_transfer_hook_rank = self.rank
             self.rank += 1
             batch.targets *= 2
             return batch

         def transfer_batch_to_device(self, batch, device, dataloader_idx):
-            assert dataloader_idx is None
+            assert dataloader_idx == 0
             self.transfer_batch_to_device_hook_rank = self.rank
             self.rank += 1
             batch.samples = batch.samples.to(device)
6 changes: 3 additions & 3 deletions tests/models/test_hooks.py
@@ -128,22 +128,22 @@ class CurrentTestModel(BoringModel):
         on_after_batch_transfer_hook_rank = None

         def on_before_batch_transfer(self, batch, dataloader_idx):
-            assert dataloader_idx is None
+            assert dataloader_idx == 0
             self.on_before_batch_transfer_hook_rank = self.rank
             self.rank += 1
             batch.samples += 1
             return batch

         def on_after_batch_transfer(self, batch, dataloader_idx):
-            assert dataloader_idx is None
+            assert dataloader_idx == 0
             assert batch.samples.device == batch.targets.device == expected_device
             self.on_after_batch_transfer_hook_rank = self.rank
             self.rank += 1
             batch.targets *= 2
             return batch

         def transfer_batch_to_device(self, batch, device, dataloader_idx):
-            assert dataloader_idx is None
+            assert dataloader_idx == 0
             self.transfer_batch_to_device_hook_rank = self.rank
             self.rank += 1
             batch.samples = batch.samples.to(device)
4 changes: 2 additions & 2 deletions tests/trainer/loops/test_evaluation_loop_flow.py
@@ -68,7 +68,7 @@ def backward(self, loss, optimizer, optimizer_idx):
     # simulate training manually
     trainer.state.stage = RunningStage.TRAINING
     batch_idx, batch = 0, next(iter(model.train_dataloader()))
-    out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx, 0)
+    out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
     assert out.signal == 0

     train_step_out = out.training_step_output
@@ -134,7 +134,7 @@ def backward(self, loss, optimizer, optimizer_idx):
     trainer.state.stage = RunningStage.TRAINING
     # make sure training outputs what is expected
     batch_idx, batch = 0, next(iter(model.train_dataloader()))
-    out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx, 0)
+    out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
     assert out.signal == 0

     train_step_out = out.training_step_output
8 changes: 4 additions & 4 deletions tests/trainer/loops/test_training_loop_flow_scalar.py
@@ -146,7 +146,7 @@ def backward(self, loss, optimizer, optimizer_idx):
     trainer.state.stage = RunningStage.TRAINING
     # make sure training outputs what is expected
     batch_idx, batch = 0, next(iter(model.train_dataloader()))
-    out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx, 0)
+    out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
     assert out.signal == 0

     train_step_out = out.training_step_output
@@ -219,7 +219,7 @@ def backward(self, loss, optimizer, optimizer_idx):
     trainer.state.stage = RunningStage.TRAINING
     # make sure training outputs what is expected
     batch_idx, batch = 0, next(iter(model.train_dataloader()))
-    out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx, 0)
+    out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
     assert out.signal == 0

     train_step_out = out.training_step_output
@@ -300,7 +300,7 @@ def training_step(self, batch, batch_idx):

     # manually check a few batches
     for batch_idx, batch in enumerate(model.train_dataloader()):
-        out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx, 0)
+        out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
         if not batch_idx % 2:
             assert out.training_step_output == [[]]
         assert out.signal == 0
@@ -344,7 +344,7 @@ def train_dataloader(self):

     # manually check a few batches
     for batch_idx, batch in enumerate(model.train_dataloader()):
-        out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx, 0)
+        out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
         if not batch_idx % 2:
             assert out.training_step_output == [[]]
         assert out.signal == 0